Skip to content

[SDAG] Split the partial reduce legalize table by opcode [nfc] #141970

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 29, 2025

Conversation

preames
Copy link
Collaborator

@preames preames commented May 29, 2025

On it's own, this change should be non-functional. This is a preparatory change for #141267 which adds a new form of PARTIAL_REDUCE_*MLA. As noted in the discussion on that review, AArch64 needs a different set of legal and custom types for the PARTIAL_REDUCE_SUMLA variant than the currently existing PARTIAL_REDUCE_UMLA/SMLA.

On it's own, this change should be non-functional. This is a preparatory change for llvm#141267 which adds a new form of PARTIAL_REDUCE_*MLA.  As noted in the discussion on that review, AArch64 needs a different set of legal and custom types for the PARTIAL_REDUCE_SUMLA variant than the currently existing PARTIAL_REDUCE_UMLA/SMLA.
@llvmbot
Copy link
Member

llvmbot commented May 29, 2025

@llvm/pr-subscribers-backend-risc-v

@llvm/pr-subscribers-backend-aarch64

Author: Philip Reames (preames)

Changes

On it's own, this change should be non-functional. This is a preparatory change for #141267 which adds a new form of PARTIAL_REDUCE_*MLA. As noted in the discussion on that review, AArch64 needs a different set of legal and custom types for the PARTIAL_REDUCE_SUMLA variant than the currently existing PARTIAL_REDUCE_UMLA/SMLA.


Full diff: https://github.com/llvm/llvm-project/pull/141970.diff

5 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+22-13)
  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+14-13)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (+3-2)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+14-9)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+8-6)
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index b818f4768c2c3..9c453f51e129d 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1659,17 +1659,20 @@ class LLVM_ABI TargetLoweringBase {
   /// InputVT should be treated. Either it's legal, needs to be promoted to a
   /// larger size, needs to be expanded to some other code sequence, or the
   /// target has a custom expander for it.
-  LegalizeAction getPartialReduceMLAAction(EVT AccVT, EVT InputVT) const {
-    PartialReduceActionTypes TypePair = {AccVT.getSimpleVT().SimpleTy,
-                                         InputVT.getSimpleVT().SimpleTy};
-    auto It = PartialReduceMLAActions.find(TypePair);
+  LegalizeAction getPartialReduceMLAAction(unsigned Opc, EVT AccVT,
+                                           EVT InputVT) const {
+    assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA);
+    PartialReduceActionTypes Key = {Opc, AccVT.getSimpleVT().SimpleTy,
+                                    InputVT.getSimpleVT().SimpleTy};
+    auto It = PartialReduceMLAActions.find(Key);
     return It != PartialReduceMLAActions.end() ? It->second : Expand;
   }
 
   /// Return true if a PARTIAL_REDUCE_U/SMLA node with the specified types is
   /// legal or custom for this target.
-  bool isPartialReduceMLALegalOrCustom(EVT AccVT, EVT InputVT) const {
-    LegalizeAction Action = getPartialReduceMLAAction(AccVT, InputVT);
+  bool isPartialReduceMLALegalOrCustom(unsigned Opc, EVT AccVT,
+                                       EVT InputVT) const {
+    LegalizeAction Action = getPartialReduceMLAAction(Opc, AccVT, InputVT);
     return Action == Legal || Action == Custom;
   }
 
@@ -2754,12 +2757,18 @@ class LLVM_ABI TargetLoweringBase {
   /// type InputVT should be treated by the target. Either it's legal, needs to
   /// be promoted to a larger size, needs to be expanded to some other code
   /// sequence, or the target has a custom expander for it.
-  void setPartialReduceMLAAction(MVT AccVT, MVT InputVT,
+  void setPartialReduceMLAAction(unsigned Opc, MVT AccVT, MVT InputVT,
                                  LegalizeAction Action) {
+    assert(Opc == ISD::PARTIAL_REDUCE_SMLA || Opc == ISD::PARTIAL_REDUCE_UMLA);
     assert(AccVT.isValid() && InputVT.isValid() &&
            "setPartialReduceMLAAction types aren't valid");
-    PartialReduceActionTypes TypePair = {AccVT.SimpleTy, InputVT.SimpleTy};
-    PartialReduceMLAActions[TypePair] = Action;
+    PartialReduceActionTypes Key = {Opc, AccVT.SimpleTy, InputVT.SimpleTy};
+    PartialReduceMLAActions[Key] = Action;
+  }
+  void setPartialReduceMLAAction(ArrayRef<unsigned> Opcodes, MVT AccVT,
+                                 MVT InputVT, LegalizeAction Action) {
+    for (unsigned Opc : Opcodes)
+      setPartialReduceMLAAction(Opc, AccVT, InputVT, Action);
   }
 
   /// If Opc/OrigVT is specified as being promoted, the promotion code defaults
@@ -3751,10 +3760,10 @@ class LLVM_ABI TargetLoweringBase {
   uint32_t CondCodeActions[ISD::SETCC_INVALID][(MVT::VALUETYPE_SIZE + 7) / 8];
 
   using PartialReduceActionTypes =
-      std::pair<MVT::SimpleValueType, MVT::SimpleValueType>;
-  /// For each result type and input type for the ISD::PARTIAL_REDUCE_U/SMLA
-  /// nodes, keep a LegalizeAction which indicates how instruction selection
-  /// should deal with this operation.
+      std::tuple<unsigned, MVT::SimpleValueType, MVT::SimpleValueType>;
+  /// For each partial reduce opcode, result type and input type combination,
+  /// keep a LegalizeAction which indicates how instruction selection should
+  /// deal with this operation.
   DenseMap<PartialReduceActionTypes, LegalizeAction> PartialReduceMLAActions;
 
   ValueTypeActionImpl ValueTypeActions;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index e05f85ea3bd8e..be2209a2f8faf 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -12673,17 +12673,17 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDValue LHSExtOp = LHS->getOperand(0);
   EVT LHSExtOpVT = LHSExtOp.getValueType();
 
+  bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
+  unsigned NewOpcode =
+      ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+
   // Only perform these combines if the target supports folding
   // the extends into the operation.
   if (!TLI.isPartialReduceMLALegalOrCustom(
-          TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
+          NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
           TLI.getTypeToTransformTo(*Context, LHSExtOpVT)))
     return SDValue();
 
-  bool ExtIsSigned = LHSOpcode == ISD::SIGN_EXTEND;
-  unsigned NewOpcode =
-      ExtIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
-
   // partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
   // -> partial_reduce_*mla(acc, x, C)
   if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
@@ -12737,14 +12737,6 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
   if (!ISD::isExtOpcode(Op1Opcode))
     return SDValue();
 
-  SDValue UnextOp1 = Op1.getOperand(0);
-  EVT UnextOp1VT = UnextOp1.getValueType();
-  auto *Context = DAG.getContext();
-  if (!TLI.isPartialReduceMLALegalOrCustom(
-          TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
-          TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
-    return SDValue();
-
   bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
   bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
   EVT AccElemVT = Acc.getValueType().getVectorElementType();
@@ -12754,6 +12746,15 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
 
   unsigned NewOpcode =
       Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+
+  SDValue UnextOp1 = Op1.getOperand(0);
+  EVT UnextOp1VT = UnextOp1.getValueType();
+  auto *Context = DAG.getContext();
+  if (!TLI.isPartialReduceMLALegalOrCustom(
+          NewOpcode, TLI.getTypeToTransformTo(*Context, N->getValueType(0)),
+          TLI.getTypeToTransformTo(*Context, UnextOp1VT)))
+    return SDValue();
+
   return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
                      DAG.getConstant(1, DL, UnextOp1VT));
 }
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index affcd78ea61b0..910a40e5b5141 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -530,8 +530,9 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   }
   case ISD::PARTIAL_REDUCE_UMLA:
   case ISD::PARTIAL_REDUCE_SMLA:
-    Action = TLI.getPartialReduceMLAAction(Node->getValueType(0),
-                                           Node->getOperand(1).getValueType());
+    Action =
+        TLI.getPartialReduceMLAAction(Op.getOpcode(), Node->getValueType(0),
+                                      Node->getOperand(1).getValueType());
     break;
 
 #define BEGIN_REGISTER_VP_SDNODE(VPID, LEGALPOS, ...)                          \
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a07afea963e20..f18d325148742 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1458,9 +1458,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::FADD, VT, Custom);
 
     if (EnablePartialReduceNodes && Subtarget->hasDotProd()) {
-      setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Legal);
-      setPartialReduceMLAAction(MVT::v2i32, MVT::v8i8, Legal);
-      setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
+      static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
+                                        ISD::PARTIAL_REDUCE_UMLA};
+
+      setPartialReduceMLAAction(MLAOps, MVT::v4i32, MVT::v16i8, Legal);
+      setPartialReduceMLAAction(MLAOps, MVT::v2i32, MVT::v8i8, Legal);
+      setPartialReduceMLAAction(MLAOps, MVT::v2i64, MVT::v16i8, Custom);
     }
 
   } else /* !isNeonAvailable */ {
@@ -1881,16 +1884,18 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   if (EnablePartialReduceNodes && Subtarget->isSVEorStreamingSVEAvailable()) {
     // Mark known legal pairs as 'Legal' (these will expand to UDOT or SDOT).
     // Other pairs will default to 'Expand'.
-    setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
-    setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
+    static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
+                                      ISD::PARTIAL_REDUCE_UMLA};
+    setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv8i16, Legal);
+    setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv16i8, Legal);
 
-    setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
+    setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv16i8, Custom);
 
     // Wide add types
     if (Subtarget->hasSVE2() || Subtarget->hasSME()) {
-      setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv4i32, Legal);
-      setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv8i16, Legal);
-      setPartialReduceMLAAction(MVT::nxv8i16, MVT::nxv16i8, Legal);
+      setPartialReduceMLAAction(MLAOps, MVT::nxv2i64, MVT::nxv4i32, Legal);
+      setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv8i16, Legal);
+      setPartialReduceMLAAction(MLAOps, MVT::nxv8i16, MVT::nxv16i8, Legal);
     }
   }
 
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 43c81b97a0e05..567f4c5b47d30 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1573,11 +1573,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
   // zve32x is broken for partial_reduce_umla, but let's not make it worse.
   if (Subtarget.hasStdExtZvqdotq() && Subtarget.getELen() >= 64) {
-    setPartialReduceMLAAction(MVT::nxv1i32, MVT::nxv4i8, Custom);
-    setPartialReduceMLAAction(MVT::nxv2i32, MVT::nxv8i8, Custom);
-    setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Custom);
-    setPartialReduceMLAAction(MVT::nxv8i32, MVT::nxv32i8, Custom);
-    setPartialReduceMLAAction(MVT::nxv16i32, MVT::nxv64i8, Custom);
+    static const unsigned MLAOps[] = {ISD::PARTIAL_REDUCE_SMLA,
+                                      ISD::PARTIAL_REDUCE_UMLA};
+    setPartialReduceMLAAction(MLAOps, MVT::nxv1i32, MVT::nxv4i8, Custom);
+    setPartialReduceMLAAction(MLAOps, MVT::nxv2i32, MVT::nxv8i8, Custom);
+    setPartialReduceMLAAction(MLAOps, MVT::nxv4i32, MVT::nxv16i8, Custom);
+    setPartialReduceMLAAction(MLAOps, MVT::nxv8i32, MVT::nxv32i8, Custom);
+    setPartialReduceMLAAction(MLAOps, MVT::nxv16i32, MVT::nxv64i8, Custom);
 
     if (Subtarget.useRVVForFixedLengthVectors()) {
       for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) {
@@ -1586,7 +1588,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
           continue;
         ElementCount EC = VT.getVectorElementCount();
         MVT ArgVT = MVT::getVectorVT(MVT::i8, EC.multiplyCoefficientBy(4));
-        setPartialReduceMLAAction(VT, ArgVT, Custom);
+        setPartialReduceMLAAction(MLAOps, VT, ArgVT, Custom);
       }
     }
   }

Copy link
Collaborator

@sdesmalen-arm sdesmalen-arm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@preames preames merged commit 1651aa2 into llvm:main May 29, 2025
10 checks passed
@preames preames deleted the pr-partial-reduce-split-legalize branch May 29, 2025 21:05
google-yfyang pushed a commit to google-yfyang/llvm-project that referenced this pull request May 29, 2025
…141970)

On it's own, this change should be non-functional. This is a preparatory
change for llvm#141267 which adds a
new form of PARTIAL_REDUCE_*MLA. As noted in the discussion on that
review, AArch64 needs a different set of legal and custom types for the
PARTIAL_REDUCE_SUMLA variant than the currently existing
PARTIAL_REDUCE_UMLA/SMLA.
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Jun 3, 2025
…141970)

On it's own, this change should be non-functional. This is a preparatory
change for llvm#141267 which adds a
new form of PARTIAL_REDUCE_*MLA. As noted in the discussion on that
review, AArch64 needs a different set of legal and custom types for the
PARTIAL_REDUCE_SUMLA variant than the currently existing
PARTIAL_REDUCE_UMLA/SMLA.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants