
[CostModel][AArch64] Make extractelement, with fmul user, free whenever possible #111479

Open
wants to merge 1 commit into base: main
Conversation

sushgokh
Contributor

@sushgokh sushgokh commented Oct 8, 2024

Make extractelement, with fmul user, free whenever possible

On Neon, if there exists an extractelement from a lane != 0 such that

  1. the extractelement does not necessitate a move from a vector register to a GPR,
  2. the extractelement result feeds into an fmul, and
  3. the other operand of the fmul is a scalar, or an extractelement from lane 0 (or from a lane equivalent to lane 0),

then the extractelement can be merged with the fmul in the backend and incurs no cost.

e.g.

define double @foo(<2 x double> %a) { 
  %1 = extractelement <2 x double> %a, i32 0 
  %2 = extractelement <2 x double> %a, i32 1
  %res = fmul double %1, %2    
  ret double %res
}

%2 and %res can be merged in the backend to generate:
fmul d0, d0, v0.d[1]

The change was tested with SPEC FP (C/C++) on a Neoverse V2.
Compile-time impact: none.
Performance impact: a 1.3-1.7% uplift on the lbm benchmark with -flto, depending on the configuration.

@llvmbot
Collaborator

llvmbot commented Oct 8, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Sushant Gokhale (sushgokh)

Changes

Make extractelement, with fmul user, free whenever possible

On Neon, if there exists an extractelement from a lane != 0 such that

  1. the extractelement does not necessitate a move from a vector register to a GPR,
  2. the extractelement result feeds into an fmul, and
  3. the other operand of the fmul is a scalar, or an extractelement from lane 0 (or from a lane equivalent to lane 0),

then the extractelement can be merged with the fmul in the backend and incurs no cost.

e.g.

define double @foo(<2 x double> %a) {
  %1 = extractelement <2 x double> %a, i32 0
  %2 = extractelement <2 x double> %a, i32 1
  %res = fmul double %1, %2
  ret double %res
}

%2 and %res can be merged in the backend to generate:
fmul d0, d0, v0.d[1]

The change was tested with SPEC FP (C/C++) on a Neoverse V2.
Compile-time impact: none.
Performance impact: a 1.3-1.7% uplift on the lbm benchmark with -flto, depending on the configuration.


Patch is 31.42 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/111479.diff

9 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+32)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+10)
  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+14-3)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+17)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+159-4)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+19-2)
  • (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+14-2)
  • (modified) llvm/test/Analysis/CostModel/AArch64/extract_float.ll (+16-13)
  • (modified) llvm/test/Transforms/SLPVectorizer/consecutive-access.ll (+22-48)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 89a85bc8a90864..7a84c31679a888 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1392,6 +1392,19 @@ class TargetTransformInfo {
                                      unsigned Index = -1, Value *Op0 = nullptr,
                                      Value *Op1 = nullptr) const;
 
+  /// \return The expected cost of vector Insert and Extract.
+  /// Use -1 to indicate that there is no information on the index value.
+  /// This is used when the instruction is not available; a typical use
+  /// case is to provision the cost of vectorization/scalarization in
+  /// vectorizer passes.
+  InstructionCost getVectorInstrCost(
+      unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+      Value *Op0, Value *Op1, Value *Scalar,
+      const DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>
+          &ScalarAndIdxToUser,
+      const DenseMap<Value *, SmallVector<std::pair<Value *, unsigned>, 4>>
+          &UserToScalarAndIdx) const;
+
   /// \return The expected cost of vector Insert and Extract.
   /// This is used when instruction is available, and implementation
   /// asserts 'I' is not nullptr.
@@ -2062,6 +2075,14 @@ class TargetTransformInfo::Concept {
                                              TTI::TargetCostKind CostKind,
                                              unsigned Index, Value *Op0,
                                              Value *Op1) = 0;
+
+  virtual InstructionCost getVectorInstrCost(
+      unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+      Value *Op0, Value *Op1, Value *Scalar,
+      const DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>
+          &ScalarAndIdxToUser,
+      const DenseMap<Value *, SmallVector<std::pair<Value *, unsigned>, 4>>
+          &UserToScalarAndIdx) = 0;
   virtual InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
                                              TTI::TargetCostKind CostKind,
                                              unsigned Index) = 0;
@@ -2726,6 +2747,17 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
                                      Value *Op1) override {
     return Impl.getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1);
   }
+  InstructionCost getVectorInstrCost(
+      unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+      Value *Op0, Value *Op1, Value *Scalar,
+      const DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>
+          &ScalarAndIdxToUser,
+      const DenseMap<Value *, SmallVector<std::pair<Value *, unsigned>, 4>>
+          &UserToScalarAndIdx) override {
+    return Impl.getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1,
+                                   Scalar, ScalarAndIdxToUser,
+                                   UserToScalarAndIdx);
+  }
   InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
                                      TTI::TargetCostKind CostKind,
                                      unsigned Index) override {
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 50040dc8f6165b..deeecc462e19c7 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -683,6 +683,16 @@ class TargetTransformInfoImplBase {
     return 1;
   }
 
+  InstructionCost getVectorInstrCost(
+      unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+      Value *Op0, Value *Op1, Value *Scalar,
+      const DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>
+          &ScalarAndIdxToUser,
+      const DenseMap<Value *, SmallVector<std::pair<Value *, unsigned>, 4>>
+          &UserToScalarAndIdx) const {
+    return 1;
+  }
+
   InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
                                      TTI::TargetCostKind CostKind,
                                      unsigned Index) const {
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index c36a346c1b2e05..48bd91c5d8c6e5 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -1277,12 +1277,23 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
     return 1;
   }
 
-  InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
-                                     TTI::TargetCostKind CostKind,
-                                     unsigned Index, Value *Op0, Value *Op1) {
+  virtual InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
+                                             TTI::TargetCostKind CostKind,
+                                             unsigned Index, Value *Op0,
+                                             Value *Op1) {
     return getRegUsageForType(Val->getScalarType());
   }
 
+  InstructionCost getVectorInstrCost(
+      unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+      Value *Op0, Value *Op1, Value *Scalar,
+      const DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>
+          &ScalarAndIdxToUser,
+      const DenseMap<Value *, SmallVector<std::pair<Value *, unsigned>, 4>>
+          &UserToScalarAndIdx) {
+    return getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1);
+  }
+
   InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
                                      TTI::TargetCostKind CostKind,
                                      unsigned Index) {
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index b5195f764cbd1c..7c588e717c4bba 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1037,6 +1037,23 @@ InstructionCost TargetTransformInfo::getVectorInstrCost(
   return Cost;
 }
 
+InstructionCost TargetTransformInfo::getVectorInstrCost(
+    unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+    Value *Op0, Value *Op1, Value *Scalar,
+    const DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>
+        &ScalarAndIdxToUser,
+    const DenseMap<Value *, SmallVector<std::pair<Value *, unsigned>, 4>>
+        &UserToScalarAndIdx) const {
+  // FIXME: Assert that Opcode is either InsertElement or ExtractElement.
+  // This is mentioned in the interface description and respected by all
+  // callers, but never asserted upon.
+  InstructionCost Cost = TTIImpl->getVectorInstrCost(
+      Opcode, Val, CostKind, Index, Op0, Op1, Scalar, ScalarAndIdxToUser,
+      UserToScalarAndIdx);
+  assert(Cost >= 0 && "TTI should not produce negative costs!");
+  return Cost;
+}
+
 InstructionCost
 TargetTransformInfo::getVectorInstrCost(const Instruction &I, Type *Val,
                                         TTI::TargetCostKind CostKind,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index 80d5168ae961ab..f86e7a3715b4e4 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -16,14 +16,20 @@
 #include "llvm/CodeGen/BasicTTIImpl.h"
 #include "llvm/CodeGen/CostTable.h"
 #include "llvm/CodeGen/TargetLowering.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/IntrinsicsAArch64.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Transforms/InstCombine/InstCombiner.h"
 #include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
 #include <algorithm>
+#include <cassert>
 #include <optional>
 using namespace llvm;
 using namespace llvm::PatternMatch;
@@ -3145,12 +3151,20 @@ InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
   return 0;
 }
 
-InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
-                                                         Type *Val,
-                                                         unsigned Index,
-                                                         bool HasRealUse) {
+InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
+    std::variant<const Instruction *, const unsigned> InstOrOpcode, Type *Val,
+    unsigned Index, bool HasRealUse, Value *Scalar,
+    const DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>
+        &ScalarAndIdxToUser,
+    const DenseMap<Value *, SmallVector<std::pair<Value *, unsigned>, 4>>
+        &UserToScalarAndIdx) {
   assert(Val->isVectorTy() && "This must be a vector type");
 
+  const Instruction *I =
+      (std::holds_alternative<const Instruction *>(InstOrOpcode)
+           ? get<const Instruction *>(InstOrOpcode)
+           : nullptr);
+
   if (Index != -1U) {
     // Legalize the type.
     std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Val);
@@ -3194,6 +3208,134 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
     // compile-time considerations.
   }
 
+  // In case of Neon, if there exists extractelement from lane != 0 such that
+  // 1. extractelement does not necessitate a move from vector_reg -> GPR.
+  // 2. extractelement result feeds into fmul.
+  // 3. Other operand of fmul is a scalar or extractelement from lane 0 or lane
+  // equivalent to 0.
+  // then the extractelement can be merged with fmul in the backend and it
+  // incurs no cost.
+  // e.g.
+  // define double @foo(<2 x double> %a) {
+  //   %1 = extractelement <2 x double> %a, i32 0
+  //   %2 = extractelement <2 x double> %a, i32 1
+  //   %res = fmul double %1, %2
+  //   ret double %res
+  // }
+  // %2 and %res can be merged in the backend to generate fmul v0, v0, v1.d[1]
+  auto ExtractCanFuseWithFmul = [&]() {
+    // We bail out if the extract is from lane 0.
+    if (Index == 0)
+      return false;
+
+    // Check if the scalar element type of the vector operand of ExtractElement
+    // instruction is one of the allowed types.
+    auto IsAllowedScalarTy = [&](const Type *T) {
+      return T->isFloatTy() || T->isDoubleTy() ||
+             (T->isHalfTy() && ST->hasFullFP16());
+    };
+
+    // Check if the extractelement user is scalar fmul.
+    auto IsUserFMulScalarTy = [](const Value *EEUser) {
+      // Check if the user is scalar fmul.
+      const BinaryOperator *BO = dyn_cast_if_present<BinaryOperator>(EEUser);
+      return BO && BO->getOpcode() == BinaryOperator::FMul &&
+             !BO->getType()->isVectorTy();
+    };
+
+    // InstCombine combines fmul with fadd/fsub. Hence, extractelement fusion
+    // with fmul does not happen.
+    auto IsFMulUserFAddFSub = [](const Value *FMul) {
+      return any_of(FMul->users(), [](const User *U) {
+        const BinaryOperator *BO = dyn_cast_if_present<BinaryOperator>(U);
+        return (BO && (BO->getOpcode() == BinaryOperator::FAdd ||
+                       BO->getOpcode() == BinaryOperator::FSub));
+      });
+    };
+
+    // Check if the type constraints on input vector type and result scalar type
+    // of extractelement instruction are satisfied.
+    auto TypeConstraintsOnEESatisfied =
+        [&IsAllowedScalarTy](const Type *VectorTy, const Type *ScalarTy) {
+          return isa<FixedVectorType>(VectorTy) && IsAllowedScalarTy(ScalarTy);
+        };
+
+    // Check if the extract index is from lane 0 or lane equivalent to 0 for a
+    // certain scalar type and a certain vector register width.
+    auto IsExtractLaneEquivalentToZero = [&](const unsigned &Idx,
+                                             const unsigned &EltSz) {
+      auto RegWidth =
+          getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
+              .getFixedValue();
+      return (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
+    };
+
+    if (std::holds_alternative<const unsigned>(InstOrOpcode)) {
+      if (!TypeConstraintsOnEESatisfied(Val, Val->getScalarType()))
+        return false;
+      const auto &ScalarIdxPair = std::make_pair(Scalar, Index);
+      return ScalarAndIdxToUser.find(ScalarIdxPair) !=
+                 ScalarAndIdxToUser.end() &&
+             all_of(ScalarAndIdxToUser.at(ScalarIdxPair), [&](Value *U) {
+               if (!IsUserFMulScalarTy(U) || IsFMulUserFAddFSub(U))
+                 return false;
+               // 1. Check if the other operand is extract from lane 0 or lane
+               // equivalent to 0.
+               // 2. In case of SLP, if the other operand is not extract from
+               // same tree, we bail out since we can not analyze that extract.
+               return UserToScalarAndIdx.at(U).size() == 2 &&
+                      all_of(UserToScalarAndIdx.at(U), [&](auto &P) {
+                        if (ScalarIdxPair == P)
+                          return true; // Skip.
+                        return IsExtractLaneEquivalentToZero(
+                            P.second, Val->getScalarSizeInBits());
+                      });
+             });
+    } else {
+      const ExtractElementInst *EE = cast<ExtractElementInst>(I);
+
+      const ConstantInt *IdxOp = dyn_cast<ConstantInt>(EE->getIndexOperand());
+      if (!IdxOp)
+        return false;
+
+      if (!TypeConstraintsOnEESatisfied(EE->getVectorOperand()->getType(),
+                                        EE->getType()))
+        return false;
+
+      return !EE->users().empty() && all_of(EE->users(), [&](const User *U) {
+        if (!IsUserFMulScalarTy(U) || IsFMulUserFAddFSub(U))
+          return false;
+
+        // Check if the other operand of extractelement is also extractelement
+        // from lane equivalent to 0.
+        const BinaryOperator *BO = cast<BinaryOperator>(U);
+        const ExtractElementInst *OtherEE = dyn_cast<ExtractElementInst>(
+            BO->getOperand(0) == EE ? BO->getOperand(1) : BO->getOperand(0));
+        if (OtherEE) {
+          const ConstantInt *IdxOp =
+              dyn_cast<ConstantInt>(OtherEE->getIndexOperand());
+          if (!IdxOp)
+            return false;
+          return IsExtractLaneEquivalentToZero(
+              cast<ConstantInt>(OtherEE->getIndexOperand())
+                  ->getValue()
+                  .getZExtValue(),
+              OtherEE->getType()->getScalarSizeInBits());
+        }
+        return true;
+      });
+    }
+    return false;
+  };
+
+  if (std::holds_alternative<const unsigned>(InstOrOpcode)) {
+    const unsigned &Opcode = get<const unsigned>(InstOrOpcode);
+    if (Opcode == Instruction::ExtractElement && ExtractCanFuseWithFmul())
+      return 0;
+  } else if (I && I->getOpcode() == Instruction::ExtractElement &&
+             ExtractCanFuseWithFmul())
+    return 0;
+
   // All other insert/extracts cost this much.
   return ST->getVectorInsertExtractBaseCost();
 }
@@ -3207,6 +3349,19 @@ InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
   return getVectorInstrCostHelper(nullptr, Val, Index, HasRealUse);
 }
 
+InstructionCost AArch64TTIImpl::getVectorInstrCost(
+    unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+    Value *Op0, Value *Op1, Value *Scalar,
+    const DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>
+        &ScalarAndIdxToUser,
+    const DenseMap<Value *, SmallVector<std::pair<Value *, unsigned>, 4>>
+        &UserToScalarAndIdx) {
+  bool HasRealUse =
+      Opcode == Instruction::InsertElement && Op0 && !isa<UndefValue>(Op0);
+  return getVectorInstrCostHelper(Opcode, Val, Index, HasRealUse, Scalar,
+                                  ScalarAndIdxToUser, UserToScalarAndIdx);
+}
+
 InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,
                                                    Type *Val,
                                                    TTI::TargetCostKind CostKind,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 28e45207596ecd..12952ba41708b1 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -24,6 +24,7 @@
 #include "llvm/CodeGen/BasicTTIImpl.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Intrinsics.h"
+#include <climits>
 #include <cstdint>
 #include <optional>
 
@@ -66,8 +67,15 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
   // 'Val' and 'Index' are forwarded from 'getVectorInstrCost'; 'HasRealUse'
   // indicates whether the vector instruction is available in the input IR or
   // just imaginary in vectorizer passes.
-  InstructionCost getVectorInstrCostHelper(const Instruction *I, Type *Val,
-                                           unsigned Index, bool HasRealUse);
+  InstructionCost getVectorInstrCostHelper(
+      std::variant<const Instruction *, const unsigned> InstOrOpcode, Type *Val,
+      unsigned Index, bool HasRealUse, Value *Scalar = nullptr,
+      const DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>
+          &ScalarAndIdxToUser =
+              DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>(),
+      const DenseMap<Value *, SmallVector<std::pair<Value *, unsigned>, 4>>
+          &UserToScalarAndIdx = DenseMap<
+              Value *, SmallVector<std::pair<Value *, unsigned>, 4>>());
 
 public:
   explicit AArch64TTIImpl(const AArch64TargetMachine *TM, const Function &F)
@@ -185,6 +193,15 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
   InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
                                      TTI::TargetCostKind CostKind,
                                      unsigned Index, Value *Op0, Value *Op1);
+
+  InstructionCost getVectorInstrCost(
+      unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+      Value *Op0, Value *Op1, Value *Scalar,
+      const DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>
+          &ScalarAndIdxToUser,
+      const DenseMap<Value *, SmallVector<std::pair<Value *, unsigned>, 4>>
+          &UserToScalarAndIdx);
+
   InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
                                      TTI::TargetCostKind CostKind,
                                      unsigned Index);
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 401597af35bdac..90a070fc3b3c48 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -11633,6 +11633,17 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
   std::optional<DenseMap<Value *, unsigned>> ValueToExtUses;
   DenseMap<const TreeEntry *, DenseSet<Value *>> ExtractsCount;
   SmallPtrSet<Value *, 4> ScalarOpsFromCasts;
+  // Keep track of {Scalar, Index} -> User and User -> {Scalar, Index}.
+  // On AArch64, this helps in fusing a mov instruction, associated with
+  // extractelement, with fmul in the backend so that extractelement is free.
+  DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>
+      ScalarAndIdxToUser;
+  DenseMap<Value *, SmallVector<std::pair<Value *, unsigned>, 4>>
+      UserToScalarAndIdx;
+  for (ExternalUser &EU : ExternalUses) {
+    UserToScalarAndIdx[EU.User].push_back({EU.Scalar, EU.Lane});
+    ScalarAndIdxToUser[{EU.Scalar, EU.Lane}].push_back(EU.User);
+  }
   for (ExternalUser &EU : ExternalUses) {
     // Uses by ephemeral values are free (because the ephemeral value will be
     // removed prior to code generation, and so the extraction will be
@@ -11739,8 +11750,9 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
       ExtraCost = TTI->getExtractWithExtendCost(Extend, EU.Scalar->getType(),
                                                 VecTy, EU.Lane);
     } else {
-      ExtraCost = TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy,
-                                          CostKind, EU.Lane);
+   ...
[truncated]

@llvmbot
Collaborator

llvmbot commented Oct 8, 2024

@llvm/pr-subscribers-llvm-analysis

+InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
+    std::variant<const Instruction *, const unsigned> InstOrOpcode, Type *Val,
+    unsigned Index, bool HasRealUse, Value *Scalar,
+    const DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>
+        &ScalarAndIdxToUser,
+    const DenseMap<Value *, SmallVector<std::pair<Value *, unsigned>, 4>>
+        &UserToScalarAndIdx) {
   assert(Val->isVectorTy() && "This must be a vector type");
 
+  const Instruction *I =
+      (std::holds_alternative<const Instruction *>(InstOrOpcode)
+           ? get<const Instruction *>(InstOrOpcode)
+           : nullptr);
+
   if (Index != -1U) {
     // Legalize the type.
     std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Val);
@@ -3194,6 +3208,134 @@ InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
     // compile-time considerations.
   }
 
+  // In case of Neon, if there exists extractelement from lane != 0 such that
+  // 1. extractelement does not necessitate a move from vector_reg -> GPR.
+  // 2. extractelement result feeds into fmul.
+  // 3. Other operand of fmul is a scalar or extractelement from lane 0 or lane
+  // equivalent to 0.
+  // then the extractelement can be merged with fmul in the backend and it
+  // incurs no cost.
+  // e.g.
+  // define double @foo(<2 x double> %a) {
+  //   %1 = extractelement <2 x double> %a, i32 0
+  //   %2 = extractelement <2 x double> %a, i32 1
+  //   %res = fmul double %1, %2
+  //   ret double %res
+  // }
+  // %2 and %res can be merged in the backend to generate fmul d0, d0, v0.d[1]
+  auto ExtractCanFuseWithFmul = [&]() {
+    // We bail out if the extract is from lane 0.
+    if (Index == 0)
+      return false;
+
+    // Check if the scalar element type of the vector operand of ExtractElement
+    // instruction is one of the allowed types.
+    auto IsAllowedScalarTy = [&](const Type *T) {
+      return T->isFloatTy() || T->isDoubleTy() ||
+             (T->isHalfTy() && ST->hasFullFP16());
+    };
+
+    // Check if the extractelement user is scalar fmul.
+    auto IsUserFMulScalarTy = [](const Value *EEUser) {
+      // Check if the user is scalar fmul.
+      const BinaryOperator *BO = dyn_cast_if_present<BinaryOperator>(EEUser);
+      return BO && BO->getOpcode() == BinaryOperator::FMul &&
+             !BO->getType()->isVectorTy();
+    };
+
+  // InstCombine combines fmul with fadd/fsub, in which case the fusion of
+  // extractelement with fmul does not happen. Bail out for such fmul users.
+    auto IsFMulUserFAddFSub = [](const Value *FMul) {
+      return any_of(FMul->users(), [](const User *U) {
+        const BinaryOperator *BO = dyn_cast_if_present<BinaryOperator>(U);
+        return (BO && (BO->getOpcode() == BinaryOperator::FAdd ||
+                       BO->getOpcode() == BinaryOperator::FSub));
+      });
+    };
+
+    // Check if the type constraints on input vector type and result scalar type
+    // of extractelement instruction are satisfied.
+    auto TypeConstraintsOnEESatisfied =
+        [&IsAllowedScalarTy](const Type *VectorTy, const Type *ScalarTy) {
+          return isa<FixedVectorType>(VectorTy) && IsAllowedScalarTy(ScalarTy);
+        };
+
+    // Check if the extract index is from lane 0 or lane equivalent to 0 for a
+    // certain scalar type and a certain vector register width.
+    auto IsExtractLaneEquivalentToZero = [&](const unsigned &Idx,
+                                             const unsigned &EltSz) {
+      auto RegWidth =
+          getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
+              .getFixedValue();
+      return (Idx == 0 || (Idx * EltSz) % RegWidth == 0);
+    };
+
+    if (std::holds_alternative<const unsigned>(InstOrOpcode)) {
+      if (!TypeConstraintsOnEESatisfied(Val, Val->getScalarType()))
+        return false;
+      const auto &ScalarIdxPair = std::make_pair(Scalar, Index);
+      return ScalarAndIdxToUser.find(ScalarIdxPair) !=
+                 ScalarAndIdxToUser.end() &&
+             all_of(ScalarAndIdxToUser.at(ScalarIdxPair), [&](Value *U) {
+               if (!IsUserFMulScalarTy(U) || IsFMulUserFAddFSub(U))
+                 return false;
+               // 1. Check if the other operand is extract from lane 0 or lane
+               // equivalent to 0.
+               // 2. In case of SLP, if the other operand is not extract from
+               // same tree, we bail out since we can not analyze that extract.
+               return UserToScalarAndIdx.at(U).size() == 2 &&
+                      all_of(UserToScalarAndIdx.at(U), [&](auto &P) {
+                        if (ScalarIdxPair == P)
+                          return true; // Skip.
+                        return IsExtractLaneEquivalentToZero(
+                            P.second, Val->getScalarSizeInBits());
+                      });
+             });
+    } else {
+      const ExtractElementInst *EE = cast<ExtractElementInst>(I);
+
+      const ConstantInt *IdxOp = dyn_cast<ConstantInt>(EE->getIndexOperand());
+      if (!IdxOp)
+        return false;
+
+      if (!TypeConstraintsOnEESatisfied(EE->getVectorOperand()->getType(),
+                                        EE->getType()))
+        return false;
+
+      return !EE->users().empty() && all_of(EE->users(), [&](const User *U) {
+        if (!IsUserFMulScalarTy(U) || IsFMulUserFAddFSub(U))
+          return false;
+
+        // Check if the other operand of extractelement is also extractelement
+        // from lane equivalent to 0.
+        const BinaryOperator *BO = cast<BinaryOperator>(U);
+        const ExtractElementInst *OtherEE = dyn_cast<ExtractElementInst>(
+            BO->getOperand(0) == EE ? BO->getOperand(1) : BO->getOperand(0));
+        if (OtherEE) {
+          const ConstantInt *IdxOp =
+              dyn_cast<ConstantInt>(OtherEE->getIndexOperand());
+          if (!IdxOp)
+            return false;
+          return IsExtractLaneEquivalentToZero(
+              cast<ConstantInt>(OtherEE->getIndexOperand())
+                  ->getValue()
+                  .getZExtValue(),
+              OtherEE->getType()->getScalarSizeInBits());
+        }
+        return true;
+      });
+    }
+    return false;
+  };
+
+  if (std::holds_alternative<const unsigned>(InstOrOpcode)) {
+    const unsigned &Opcode = get<const unsigned>(InstOrOpcode);
+    if (Opcode == Instruction::ExtractElement && ExtractCanFuseWithFmul())
+      return 0;
+  } else if (I && I->getOpcode() == Instruction::ExtractElement &&
+             ExtractCanFuseWithFmul())
+    return 0;
+
   // All other insert/extracts cost this much.
   return ST->getVectorInsertExtractBaseCost();
 }
@@ -3207,6 +3349,19 @@ InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
   return getVectorInstrCostHelper(nullptr, Val, Index, HasRealUse);
 }
 
+InstructionCost AArch64TTIImpl::getVectorInstrCost(
+    unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+    Value *Op0, Value *Op1, Value *Scalar,
+    const DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>
+        &ScalarAndIdxToUser,
+    const DenseMap<Value *, SmallVector<std::pair<Value *, unsigned>, 4>>
+        &UserToScalarAndIdx) {
+  bool HasRealUse =
+      Opcode == Instruction::InsertElement && Op0 && !isa<UndefValue>(Op0);
+  return getVectorInstrCostHelper(Opcode, Val, Index, HasRealUse, Scalar,
+                                  ScalarAndIdxToUser, UserToScalarAndIdx);
+}
+
 InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,
                                                    Type *Val,
                                                    TTI::TargetCostKind CostKind,
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
index 28e45207596ecd..12952ba41708b1 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -24,6 +24,7 @@
 #include "llvm/CodeGen/BasicTTIImpl.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/Intrinsics.h"
+#include <climits>
 #include <cstdint>
 #include <optional>
 
@@ -66,8 +67,15 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
   // 'Val' and 'Index' are forwarded from 'getVectorInstrCost'; 'HasRealUse'
   // indicates whether the vector instruction is available in the input IR or
   // just imaginary in vectorizer passes.
-  InstructionCost getVectorInstrCostHelper(const Instruction *I, Type *Val,
-                                           unsigned Index, bool HasRealUse);
+  InstructionCost getVectorInstrCostHelper(
+      std::variant<const Instruction *, const unsigned> InstOrOpcode, Type *Val,
+      unsigned Index, bool HasRealUse, Value *Scalar = nullptr,
+      const DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>
+          &ScalarAndIdxToUser =
+              DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>(),
+      const DenseMap<Value *, SmallVector<std::pair<Value *, unsigned>, 4>>
+          &UserToScalarAndIdx = DenseMap<
+              Value *, SmallVector<std::pair<Value *, unsigned>, 4>>());
 
 public:
   explicit AArch64TTIImpl(const AArch64TargetMachine *TM, const Function &F)
@@ -185,6 +193,15 @@ class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> {
   InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
                                      TTI::TargetCostKind CostKind,
                                      unsigned Index, Value *Op0, Value *Op1);
+
+  InstructionCost getVectorInstrCost(
+      unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+      Value *Op0, Value *Op1, Value *Scalar,
+      const DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>
+          &ScalarAndIdxToUser,
+      const DenseMap<Value *, SmallVector<std::pair<Value *, unsigned>, 4>>
+          &UserToScalarAndIdx);
+
   InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
                                      TTI::TargetCostKind CostKind,
                                      unsigned Index);
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 401597af35bdac..90a070fc3b3c48 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -11633,6 +11633,17 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
   std::optional<DenseMap<Value *, unsigned>> ValueToExtUses;
   DenseMap<const TreeEntry *, DenseSet<Value *>> ExtractsCount;
   SmallPtrSet<Value *, 4> ScalarOpsFromCasts;
+  // Keep track of {Scalar, Index} -> User and User -> {Scalar, Index}.
+  // On AArch64, this helps in fusing a mov instruction, associated with
+  // extractelement, with fmul in the backend so that extractelement is free.
+  DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>
+      ScalarAndIdxToUser;
+  DenseMap<Value *, SmallVector<std::pair<Value *, unsigned>, 4>>
+      UserToScalarAndIdx;
+  for (ExternalUser &EU : ExternalUses) {
+    UserToScalarAndIdx[EU.User].push_back({EU.Scalar, EU.Lane});
+    ScalarAndIdxToUser[{EU.Scalar, EU.Lane}].push_back(EU.User);
+  }
   for (ExternalUser &EU : ExternalUses) {
     // Uses by ephemeral values are free (because the ephemeral value will be
     // removed prior to code generation, and so the extraction will be
@@ -11739,8 +11750,9 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
       ExtraCost = TTI->getExtractWithExtendCost(Extend, EU.Scalar->getType(),
                                                 VecTy, EU.Lane);
     } else {
-      ExtraCost = TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy,
-                                          CostKind, EU.Lane);
+   ...
[truncated]

@llvmbot
Collaborator

llvmbot commented Oct 8, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Sushant Gokhale (sushgokh)

Changes


Patch is 31.42 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/111479.diff

9 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+32)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+10)
  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+14-3)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+17)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+159-4)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+19-2)
  • (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+14-2)
  • (modified) llvm/test/Analysis/CostModel/AArch64/extract_float.ll (+16-13)
  • (modified) llvm/test/Transforms/SLPVectorizer/consecutive-access.ll (+22-48)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 89a85bc8a90864..7a84c31679a888 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1392,6 +1392,19 @@ class TargetTransformInfo {
                                      unsigned Index = -1, Value *Op0 = nullptr,
                                      Value *Op1 = nullptr) const;
 
+  /// \return The expected cost of vector Insert and Extract.
+  /// Use -1 to indicate that there is no information on the index value.
+  /// This is used when the instruction is not available; a typical use
+  /// case is to provision the cost of vectorization/scalarization in
+  /// vectorizer passes.
+  InstructionCost getVectorInstrCost(
+      unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+      Value *Op0, Value *Op1, Value *Scalar,
+      const DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>
+          &ScalarAndIdxToUser,
+      const DenseMap<Value *, SmallVector<std::pair<Value *, unsigned>, 4>>
+          &UserToScalarAndIdx) const;
+
   /// \return The expected cost of vector Insert and Extract.
   /// This is used when instruction is available, and implementation
   /// asserts 'I' is not nullptr.
@@ -2062,6 +2075,14 @@ class TargetTransformInfo::Concept {
                                              TTI::TargetCostKind CostKind,
                                              unsigned Index, Value *Op0,
                                              Value *Op1) = 0;
+
+  virtual InstructionCost getVectorInstrCost(
+      unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+      Value *Op0, Value *Op1, Value *Scalar,
+      const DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>
+          &ScalarAndIdxToUser,
+      const DenseMap<Value *, SmallVector<std::pair<Value *, unsigned>, 4>>
+          &UserToScalarAndIdx) = 0;
   virtual InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
                                              TTI::TargetCostKind CostKind,
                                              unsigned Index) = 0;
@@ -2726,6 +2747,17 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
                                      Value *Op1) override {
     return Impl.getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1);
   }
+  InstructionCost getVectorInstrCost(
+      unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+      Value *Op0, Value *Op1, Value *Scalar,
+      const DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>
+          &ScalarAndIdxToUser,
+      const DenseMap<Value *, SmallVector<std::pair<Value *, unsigned>, 4>>
+          &UserToScalarAndIdx) override {
+    return Impl.getVectorInstrCost(Opcode, Val, CostKind, Index, Op0, Op1,
+                                   Scalar, ScalarAndIdxToUser,
+                                   UserToScalarAndIdx);
+  }
   InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
                                      TTI::TargetCostKind CostKind,
                                      unsigned Index) override {
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 50040dc8f6165b..deeecc462e19c7 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -683,6 +683,16 @@ class TargetTransformInfoImplBase {
     return 1;
   }
 
+  InstructionCost getVectorInstrCost(
+      unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
+      Value *Op0, Value *Op1, Value *Scalar,
+      const DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>
+          &ScalarAndIdxToUser,
+      const DenseMap<Value *, SmallVector<std::pair<Value *, unsigned>, 4>>
+          &UserToScalarAndIdx) const {
+    return 1;
+  }
+
   InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
                                      TTI::TargetCostKind CostKind,
                                      unsigned Index) const {
+      const DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>
+          &ScalarAndIdxToUser,
+      const DenseMap<Value *, SmallVector<std::pair<Value *, unsigned>, 4>>
+          &UserToScalarAndIdx);
+
   InstructionCost getVectorInstrCost(const Instruction &I, Type *Val,
                                      TTI::TargetCostKind CostKind,
                                      unsigned Index);
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 401597af35bdac..90a070fc3b3c48 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -11633,6 +11633,17 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
   std::optional<DenseMap<Value *, unsigned>> ValueToExtUses;
   DenseMap<const TreeEntry *, DenseSet<Value *>> ExtractsCount;
   SmallPtrSet<Value *, 4> ScalarOpsFromCasts;
+  // Keep track of {Scalar, Index} -> User and User -> {Scalar, Index}.
+  // On AArch64, this helps in fusing a mov instruction, associated with
+  // extractelement, with fmul in the backend so that extractelement is free.
+  DenseMap<std::pair<Value *, unsigned>, SmallVector<Value *, 4>>
+      ScalarAndIdxToUser;
+  DenseMap<Value *, SmallVector<std::pair<Value *, unsigned>, 4>>
+      UserToScalarAndIdx;
+  for (ExternalUser &EU : ExternalUses) {
+    UserToScalarAndIdx[EU.User].push_back({EU.Scalar, EU.Lane});
+    ScalarAndIdxToUser[{EU.Scalar, EU.Lane}].push_back(EU.User);
+  }
   for (ExternalUser &EU : ExternalUses) {
     // Uses by ephemeral values are free (because the ephemeral value will be
     // removed prior to code generation, and so the extraction will be
@@ -11739,8 +11750,9 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
       ExtraCost = TTI->getExtractWithExtendCost(Extend, EU.Scalar->getType(),
                                                 VecTy, EU.Lane);
     } else {
-      ExtraCost = TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy,
-                                          CostKind, EU.Lane);
+   ...
[truncated]
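The bidirectional bookkeeping added in the SLPVectorizer hunk above can be sketched in isolation (a hedged Python sketch; the `ExternalUses` tuples and names are stand-ins mirroring the diff, not LLVM API):

```python
from collections import defaultdict

# Mirror of the two DenseMaps the hunk builds from ExternalUses:
#   {(Scalar, Lane): [Users]} and {User: [(Scalar, Lane)]}.
external_uses = [
    ("scalar0", 0, "fmul_user"),
    ("scalar1", 1, "fmul_user"),
    ("scalar1", 1, "other_user"),
]

scalar_and_idx_to_user = defaultdict(list)
user_to_scalar_and_idx = defaultdict(list)
for scalar, lane, user in external_uses:
    user_to_scalar_and_idx[user].append((scalar, lane))
    scalar_and_idx_to_user[(scalar, lane)].append(user)

# Both directions are then available to the AArch64 cost hook.
print(scalar_and_idx_to_user[("scalar1", 1)])  # ['fmul_user', 'other_user']
print(user_to_scalar_and_idx["fmul_user"])     # [('scalar0', 0), ('scalar1', 1)]
```

Keeping both directions lets the cost hook answer either "who uses this scalar at this lane?" or "which scalars feed this user?" without re-walking `ExternalUses`.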

@sushgokh sushgokh force-pushed the extractelement-costing-with-fmul-user branch from 2471f47 to 85ca8b2 Compare October 9, 2024 07:57

github-actions bot commented Oct 9, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 924a64a3486f9962c42d4ec253774eb2c586ac33 08e28793106f65f4cd7eca281ca62bde1cccce40 --extensions cpp,h -- llvm/include/llvm/Analysis/TargetTransformInfo.h llvm/include/llvm/Analysis/TargetTransformInfoImpl.h llvm/include/llvm/CodeGen/BasicTTIImpl.h llvm/lib/Analysis/TargetTransformInfo.cpp llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index dac3d4a009..736caf9e4d 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -31,7 +31,7 @@
 #include "llvm/Transforms/InstCombine/InstCombiner.h"
 #include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
 #include <algorithm>
-//#include <cassert>
+// #include <cassert>
 #include <optional>
 using namespace llvm;
 using namespace llvm::PatternMatch;

Comment on lines 25 to 26
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you actually need these includes here or the user of the header file already have these includes?

Contributor Author

For some users, that's redundant. Will remove the redundant usages. Thanks.

@sushgokh sushgokh force-pushed the extractelement-costing-with-fmul-user branch from 85ca8b2 to c88143d Compare October 10, 2024 05:53
…er possible

In case of Neon, if there exists extractelement from lane != 0 such that
  1. extractelement does not necessitate a move from vector_reg -> GPR.
  2. extractelement result feeds into fmul.
  3. Other operand of fmul is a scalar or extractelement from lane 0 or lane equivalent to 0.
  then the extractelement can be merged with fmul in the backend and it incurs no cost.
  e.g.
  define double @foo(<2 x double> %a) {
    %1 = extractelement <2 x double> %a, i32 0
    %2 = extractelement <2 x double> %a, i32 1
  %res = fmul double %1, %2
  ret double %res
  }
  %2 and %res can be merged in the backend to generate:
  fmul    d0, d0, v0.d[1]

The change was tested with SPEC FP(C/C++) on Neoverse-v2.
Compile time impact: None
Performance impact: Observing 1.3-1.7% uplift on lbm benchmark with -flto depending upon the config.
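The three conditions in the commit message can be sketched as a predicate (a hedged Python sketch, not the LLVM implementation; the operand encoding and the simplified lane-0 handling are assumptions for illustration):

```python
def extract_is_free_for_fmul(lane, other_operand):
    """other_operand: ("scalar",) or ("extract", lane) -- a simplified
    stand-in for the fmul's other operand in the IR."""
    # Condition 1 (simplified): an FP extract stays in a SIMD register,
    # so no vector_reg -> GPR move is needed; modeled as always true here.
    if lane == 0:
        # Lane 0 is the register itself; no lane-indexed fmul is needed,
        # and the existing cost model already covers this case.
        return False
    # Conditions 2 and 3: the extract result feeds an fmul whose other
    # operand is a scalar or an extract from lane 0 (or equivalent to 0).
    kind = other_operand[0]
    if kind == "scalar":
        return True
    if kind == "extract" and other_operand[1] == 0:
        return True
    return False

# The commit-message example: fmul of lane 0 with lane 1 folds to
# fmul d0, d0, v0.d[1], so the lane-1 extract is free.
print(extract_is_free_for_fmul(1, ("extract", 0)))  # True
# Two nonzero-lane extracts cannot both use the by-element form.
print(extract_is_free_for_fmul(1, ("extract", 1)))  # False
```

The asymmetry matters because the AArch64 `FMUL (by element)` form indexes only one operand; the other must already sit in the low lane of a register.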
@sushgokh sushgokh force-pushed the extractelement-costing-with-fmul-user branch from c88143d to 08e2879 Compare October 10, 2024 05:59
@sushgokh
Contributor Author

ping
