Skip to content

Conversation

woruyu
Copy link
Member

@woruyu woruyu commented Jul 4, 2025

Summary

This PR resolves #146871
This PR resolves #140745

Refactor m_Zero/m_One/m_AllOnes all use struct template function to match and AllowUndefs=false as default.

@llvmbot llvmbot added backend:X86 llvm:SelectionDAG SelectionDAGISel as well labels Jul 4, 2025
@llvmbot
Copy link
Member

llvmbot commented Jul 4, 2025

@llvm/pr-subscribers-backend-x86

Author: woruyu (woruyu)

Changes

This PR resolves #146871


Full diff: https://github.com/llvm/llvm-project/pull/147044.diff

5 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/SDPatternMatch.h (+33-6)
  • (modified) llvm/include/llvm/CodeGen/SelectionDAGNodes.h (+4)
  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+1-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+12)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+9-11)
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 35322c32a8283..7c5cdbbeb0ca8 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -1100,19 +1100,46 @@ inline SpecificInt_match m_SpecificInt(uint64_t V) {
   return SpecificInt_match(APInt(64, V));
 }
 
-inline SpecificInt_match m_Zero() { return m_SpecificInt(0U); }
-inline SpecificInt_match m_One() { return m_SpecificInt(1U); }
+struct Zero_match {
+  bool AllowUndefs;
+
+  explicit Zero_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
+
+  template <typename MatchContext>
+  bool match(const MatchContext &, SDValue N) const {
+    return isZeroOrZeroSplat(N, AllowUndefs);
+  }
+};
+
+struct Ones_match {
+  bool AllowUndefs;
+
+  Ones_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
+
+  template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
+    return isOnesOrOnesSplat(N, AllowUndefs);
+  }
+};
 
 struct AllOnes_match {
+  bool AllowUndefs;
 
-  AllOnes_match() = default;
+  AllOnes_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
 
   template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
-    return isAllOnesOrAllOnesSplat(N);
+    return isAllOnesOrAllOnesSplat(N, AllowUndefs);
   }
 };
 
-inline AllOnes_match m_AllOnes() { return AllOnes_match(); }
+inline Ones_match m_One(bool AllowUndefs = false) {
+  return Ones_match(AllowUndefs);
+}
+inline Zero_match m_Zero(bool AllowUndefs = false) {
+  return Zero_match(AllowUndefs);
+}
+inline AllOnes_match m_AllOnes(bool AllowUndefs = false) {
+  return AllOnes_match(AllowUndefs);
+}
 
 /// Match true boolean value based on the information provided by
 /// TargetLowering.
@@ -1189,7 +1216,7 @@ inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) {
 
 /// Match a negate as a sub(0, v)
 template <typename ValTy>
-inline BinaryOpc_match<SpecificInt_match, ValTy> m_Neg(const ValTy &V) {
+inline BinaryOpc_match<Zero_match, ValTy, false> m_Neg(const ValTy &V) {
   return m_Sub(m_Zero(), V);
 }
 
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index a3675eecfea3f..6bfc40afeb55e 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1937,6 +1937,10 @@ LLVM_ABI bool isOneOrOneSplat(SDValue V, bool AllowUndefs = false);
 /// Does not permit build vector implicit truncation.
 LLVM_ABI bool isAllOnesOrAllOnesSplat(SDValue V, bool AllowUndefs = false);
 
+LLVM_ABI bool isOnesOrOnesSplat(SDValue N, bool AllowUndefs = false);
+
+LLVM_ABI bool isZeroOrZeroSplat(SDValue N, bool AllowUndefs = false);
+
 /// Return true if \p V is either a integer or FP constant.
 inline bool isIntOrFPConstant(SDValue V) {
   return isa<ConstantSDNode>(V) || isa<ConstantFPSDNode>(V);
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 586eb2f3cf45e..db53fb92ae08b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -4281,7 +4281,7 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
     return V;
 
   // (A - B) - 1  ->  add (xor B, -1), A
-  if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One())))
+  if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One(true))))
     return DAG.getNode(ISD::ADD, DL, VT, A, DAG.getNOT(DL, B, VT));
 
   // Look for:
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 2a3c8e2b011ad..d6605c3ec77dd 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -12569,6 +12569,18 @@ bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) {
   return C && C->isAllOnes() && C->getValueSizeInBits(0) == BitWidth;
 }
 
+bool llvm::isOnesOrOnesSplat(SDValue N, bool AllowUndefs) {
+  N = peekThroughBitcasts(N);
+  ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs);
+  return C && C->getAPIntValue() == 1;
+}
+
+bool llvm::isZeroOrZeroSplat(SDValue N, bool AllowUndefs) {
+  N = peekThroughBitcasts(N);
+  ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs, true);
+  return C && C->isZero();
+}
+
 HandleSDNode::~HandleSDNode() {
   DropOperands();
 }
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 6edbb7b1bae95..1128406236a20 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -57923,22 +57923,20 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
     }
   }
 
+  SDValue X, Y;
+
   // add(psadbw(X,0),psadbw(Y,0)) -> psadbw(add(X,Y),0)
   // iff X and Y won't overflow.
-  if (Op0.getOpcode() == X86ISD::PSADBW && Op1.getOpcode() == X86ISD::PSADBW &&
-      ISD::isBuildVectorAllZeros(Op0.getOperand(1).getNode()) &&
-      ISD::isBuildVectorAllZeros(Op1.getOperand(1).getNode())) {
-    if (DAG.willNotOverflowAdd(false, Op0.getOperand(0), Op1.getOperand(0))) {
-      MVT OpVT = Op0.getOperand(1).getSimpleValueType();
-      SDValue Sum =
-          DAG.getNode(ISD::ADD, DL, OpVT, Op0.getOperand(0), Op1.getOperand(0));
-      return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum,
-                         getZeroVector(OpVT, Subtarget, DAG, DL));
-    }
+  if (sd_match(Op0, m_c_BinOp(X86ISD::PSADBW, m_Value(X), m_Zero())) &&
+      sd_match(Op1, m_c_BinOp(X86ISD::PSADBW, m_Value(Y), m_Zero())) &&
+      DAG.willNotOverflowAdd(/*IsSigned=*/false, X, Y)) {
+    MVT OpVT = X.getSimpleValueType();
+    SDValue Sum = DAG.getNode(ISD::ADD, DL, OpVT, X, Y);
+    return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum,
+                       getZeroVector(OpVT, Subtarget, DAG, DL));
   }
 
   if (VT.isVector()) {
-    SDValue X, Y;
     EVT BoolVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
                                   VT.getVectorElementCount());
 

@llvmbot
Copy link
Member

llvmbot commented Jul 4, 2025

@llvm/pr-subscribers-llvm-selectiondag

Author: woruyu (woruyu)

Changes

This PR resolves #146871


Full diff: https://github.com/llvm/llvm-project/pull/147044.diff

5 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/SDPatternMatch.h (+33-6)
  • (modified) llvm/include/llvm/CodeGen/SelectionDAGNodes.h (+4)
  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+1-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+12)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+9-11)
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h
index 35322c32a8283..7c5cdbbeb0ca8 100644
--- a/llvm/include/llvm/CodeGen/SDPatternMatch.h
+++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h
@@ -1100,19 +1100,46 @@ inline SpecificInt_match m_SpecificInt(uint64_t V) {
   return SpecificInt_match(APInt(64, V));
 }
 
-inline SpecificInt_match m_Zero() { return m_SpecificInt(0U); }
-inline SpecificInt_match m_One() { return m_SpecificInt(1U); }
+struct Zero_match {
+  bool AllowUndefs;
+
+  explicit Zero_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
+
+  template <typename MatchContext>
+  bool match(const MatchContext &, SDValue N) const {
+    return isZeroOrZeroSplat(N, AllowUndefs);
+  }
+};
+
+struct Ones_match {
+  bool AllowUndefs;
+
+  Ones_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
+
+  template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
+    return isOnesOrOnesSplat(N, AllowUndefs);
+  }
+};
 
 struct AllOnes_match {
+  bool AllowUndefs;
 
-  AllOnes_match() = default;
+  AllOnes_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
 
   template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
-    return isAllOnesOrAllOnesSplat(N);
+    return isAllOnesOrAllOnesSplat(N, AllowUndefs);
   }
 };
 
-inline AllOnes_match m_AllOnes() { return AllOnes_match(); }
+inline Ones_match m_One(bool AllowUndefs = false) {
+  return Ones_match(AllowUndefs);
+}
+inline Zero_match m_Zero(bool AllowUndefs = false) {
+  return Zero_match(AllowUndefs);
+}
+inline AllOnes_match m_AllOnes(bool AllowUndefs = false) {
+  return AllOnes_match(AllowUndefs);
+}
 
 /// Match true boolean value based on the information provided by
 /// TargetLowering.
@@ -1189,7 +1216,7 @@ inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) {
 
 /// Match a negate as a sub(0, v)
 template <typename ValTy>
-inline BinaryOpc_match<SpecificInt_match, ValTy> m_Neg(const ValTy &V) {
+inline BinaryOpc_match<Zero_match, ValTy, false> m_Neg(const ValTy &V) {
   return m_Sub(m_Zero(), V);
 }
 
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index a3675eecfea3f..6bfc40afeb55e 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1937,6 +1937,10 @@ LLVM_ABI bool isOneOrOneSplat(SDValue V, bool AllowUndefs = false);
 /// Does not permit build vector implicit truncation.
 LLVM_ABI bool isAllOnesOrAllOnesSplat(SDValue V, bool AllowUndefs = false);
 
+LLVM_ABI bool isOnesOrOnesSplat(SDValue N, bool AllowUndefs = false);
+
+LLVM_ABI bool isZeroOrZeroSplat(SDValue N, bool AllowUndefs = false);
+
 /// Return true if \p V is either a integer or FP constant.
 inline bool isIntOrFPConstant(SDValue V) {
   return isa<ConstantSDNode>(V) || isa<ConstantFPSDNode>(V);
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 586eb2f3cf45e..db53fb92ae08b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -4281,7 +4281,7 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
     return V;
 
   // (A - B) - 1  ->  add (xor B, -1), A
-  if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One())))
+  if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One(true))))
     return DAG.getNode(ISD::ADD, DL, VT, A, DAG.getNOT(DL, B, VT));
 
   // Look for:
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 2a3c8e2b011ad..d6605c3ec77dd 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -12569,6 +12569,18 @@ bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) {
   return C && C->isAllOnes() && C->getValueSizeInBits(0) == BitWidth;
 }
 
+bool llvm::isOnesOrOnesSplat(SDValue N, bool AllowUndefs) {
+  N = peekThroughBitcasts(N);
+  ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs);
+  return C && C->getAPIntValue() == 1;
+}
+
+bool llvm::isZeroOrZeroSplat(SDValue N, bool AllowUndefs) {
+  N = peekThroughBitcasts(N);
+  ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs, true);
+  return C && C->isZero();
+}
+
 HandleSDNode::~HandleSDNode() {
   DropOperands();
 }
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 6edbb7b1bae95..1128406236a20 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -57923,22 +57923,20 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
     }
   }
 
+  SDValue X, Y;
+
   // add(psadbw(X,0),psadbw(Y,0)) -> psadbw(add(X,Y),0)
   // iff X and Y won't overflow.
-  if (Op0.getOpcode() == X86ISD::PSADBW && Op1.getOpcode() == X86ISD::PSADBW &&
-      ISD::isBuildVectorAllZeros(Op0.getOperand(1).getNode()) &&
-      ISD::isBuildVectorAllZeros(Op1.getOperand(1).getNode())) {
-    if (DAG.willNotOverflowAdd(false, Op0.getOperand(0), Op1.getOperand(0))) {
-      MVT OpVT = Op0.getOperand(1).getSimpleValueType();
-      SDValue Sum =
-          DAG.getNode(ISD::ADD, DL, OpVT, Op0.getOperand(0), Op1.getOperand(0));
-      return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum,
-                         getZeroVector(OpVT, Subtarget, DAG, DL));
-    }
+  if (sd_match(Op0, m_c_BinOp(X86ISD::PSADBW, m_Value(X), m_Zero())) &&
+      sd_match(Op1, m_c_BinOp(X86ISD::PSADBW, m_Value(Y), m_Zero())) &&
+      DAG.willNotOverflowAdd(/*IsSigned=*/false, X, Y)) {
+    MVT OpVT = X.getSimpleValueType();
+    SDValue Sum = DAG.getNode(ISD::ADD, DL, OpVT, X, Y);
+    return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum,
+                       getZeroVector(OpVT, Subtarget, DAG, DL));
   }
 
   if (VT.isVector()) {
-    SDValue X, Y;
     EVT BoolVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
                                   VT.getVectorElementCount());
 

@woruyu woruyu force-pushed the feat/SDPatternMatch-m_Zero-m_One-m_AllOnes branch from e71c622 to 229ec28 Compare July 7, 2025 02:18
@woruyu
Copy link
Member Author

woruyu commented Jul 7, 2025

Hello, any suggestion for this PR, thank you! @RKSimon @mshockwave

Copy link
Contributor

@arsenm arsenm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably should get tests in unittests/CodeGen/SelectionDAGPatternMatchTest.cpp

@woruyu
Copy link
Member Author

woruyu commented Jul 7, 2025

Probably should get tests in unittests/CodeGen/SelectionDAGPatternMatchTest.cpp

That makes sense,I think adding test for m_Zero/m_One/m_AllOnes default behavior and supportting peekthrough bitcast is a better choice, I will add it!

Copy link
Collaborator

@RKSimon RKSimon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM - cheers

@RKSimon RKSimon merged commit c80fa23 into llvm:main Jul 7, 2025
9 checks passed
mshockwave pushed a commit that referenced this pull request Jul 8, 2025
…hTest (#147443)

### Summary
This PR remove the extra llvm::SDPatternMatch prefix in
#147044
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Jul 8, 2025
…PatternMatchTest (#147443)

### Summary
This PR remove the extra llvm::SDPatternMatch prefix in
llvm/llvm-project#147044
@RKSimon
Copy link
Collaborator

RKSimon commented Jul 8, 2025

@woruyu do you have the add(psadbw(X,0),psadbw(Y,0)) -> psadbw(add(X,Y),0) patch somewhere?

@woruyu
Copy link
Member Author

woruyu commented Jul 8, 2025

@woruyu do you have the add(psadbw(X,0),psadbw(Y,0)) -> psadbw(add(X,Y),0) patch somewhere?

I originally added the SDPattern for add(psadbw(X,0), psadbw(Y,0)) -> psadbw(add(X,Y),0) in commit e71c622, but removed it during a force-push in 229ec28 based on review feedback. It’s not included in the final PR.

@RKSimon
Copy link
Collaborator

RKSimon commented Jul 8, 2025

Sorry for the confusion - I meant that I asked you in the review to remove it from #147044 but then create a separate PR for it as a followup

@woruyu
Copy link
Member Author

woruyu commented Jul 8, 2025

Sorry for the confusion - I meant that I asked you in the review to remove it from #147044 but then create a separate PR for it as a followup

Thanks a lot for the clarification — understand, I’d be happy to submit it as a follow-up patch.

@woruyu
Copy link
Member Author

woruyu commented Jul 9, 2025

Sorry for the confusion - I meant that I asked you in the review to remove it from #147044 but then create a separate PR for it as a followup

The related PR is #147637

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:X86 llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[DAG] SDPatternMatch m_Zero/m_One/m_AllOnes have inconsistent undef handling [DAG] SDPatternMatch - m_Zero can't see through bitcasts
5 participants