Skip to content

Commit e71c622

Browse files
committed
[DAG] SDPatternMatch m_Zero/m_One/m_AllOnes have inconsistent undef handling
1 parent 4a6d78e commit e71c622

File tree

5 files changed

+59
-18
lines changed

5 files changed

+59
-18
lines changed

llvm/include/llvm/CodeGen/SDPatternMatch.h

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,19 +1100,46 @@ inline SpecificInt_match m_SpecificInt(uint64_t V) {
11001100
return SpecificInt_match(APInt(64, V));
11011101
}
11021102

1103-
inline SpecificInt_match m_Zero() { return m_SpecificInt(0U); }
1104-
inline SpecificInt_match m_One() { return m_SpecificInt(1U); }
1103+
struct Zero_match {
1104+
bool AllowUndefs;
1105+
1106+
explicit Zero_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
1107+
1108+
template <typename MatchContext>
1109+
bool match(const MatchContext &, SDValue N) const {
1110+
return isZeroOrZeroSplat(N, AllowUndefs);
1111+
}
1112+
};
1113+
1114+
struct Ones_match {
1115+
bool AllowUndefs;
1116+
1117+
Ones_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
1118+
1119+
template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1120+
return isOnesOrOnesSplat(N, AllowUndefs);
1121+
}
1122+
};
11051123

11061124
struct AllOnes_match {
1125+
bool AllowUndefs;
11071126

1108-
AllOnes_match() = default;
1127+
AllOnes_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
11091128

11101129
template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1111-
return isAllOnesOrAllOnesSplat(N);
1130+
return isAllOnesOrAllOnesSplat(N, AllowUndefs);
11121131
}
11131132
};
11141133

1115-
inline AllOnes_match m_AllOnes() { return AllOnes_match(); }
1134+
inline Ones_match m_One(bool AllowUndefs = false) {
1135+
return Ones_match(AllowUndefs);
1136+
}
1137+
inline Zero_match m_Zero(bool AllowUndefs = false) {
1138+
return Zero_match(AllowUndefs);
1139+
}
1140+
inline AllOnes_match m_AllOnes(bool AllowUndefs = false) {
1141+
return AllOnes_match(AllowUndefs);
1142+
}
11161143

11171144
/// Match true boolean value based on the information provided by
11181145
/// TargetLowering.
@@ -1189,7 +1216,7 @@ inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) {
11891216

11901217
/// Match a negate as a sub(0, v)
11911218
template <typename ValTy>
1192-
inline BinaryOpc_match<SpecificInt_match, ValTy> m_Neg(const ValTy &V) {
1219+
inline BinaryOpc_match<Zero_match, ValTy, false> m_Neg(const ValTy &V) {
11931220
return m_Sub(m_Zero(), V);
11941221
}
11951222

llvm/include/llvm/CodeGen/SelectionDAGNodes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1937,6 +1937,10 @@ LLVM_ABI bool isOneOrOneSplat(SDValue V, bool AllowUndefs = false);
19371937
/// Does not permit build vector implicit truncation.
19381938
LLVM_ABI bool isAllOnesOrAllOnesSplat(SDValue V, bool AllowUndefs = false);
19391939

1940+
LLVM_ABI bool isOnesOrOnesSplat(SDValue N, bool AllowUndefs = false);
1941+
1942+
LLVM_ABI bool isZeroOrZeroSplat(SDValue N, bool AllowUndefs = false);
1943+
19401944
/// Return true if \p V is either a integer or FP constant.
19411945
inline bool isIntOrFPConstant(SDValue V) {
19421946
return isa<ConstantSDNode>(V) || isa<ConstantFPSDNode>(V);

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4281,7 +4281,7 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
42814281
return V;
42824282

42834283
// (A - B) - 1 -> add (xor B, -1), A
4284-
if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One())))
4284+
if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One(true))))
42854285
return DAG.getNode(ISD::ADD, DL, VT, A, DAG.getNOT(DL, B, VT));
42864286

42874287
// Look for:

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12569,6 +12569,18 @@ bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) {
1256912569
return C && C->isAllOnes() && C->getValueSizeInBits(0) == BitWidth;
1257012570
}
1257112571

12572+
bool llvm::isOnesOrOnesSplat(SDValue N, bool AllowUndefs) {
12573+
N = peekThroughBitcasts(N);
12574+
ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs);
12575+
return C && C->getAPIntValue() == 1;
12576+
}
12577+
12578+
bool llvm::isZeroOrZeroSplat(SDValue N, bool AllowUndefs) {
12579+
N = peekThroughBitcasts(N);
12580+
ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs, true);
12581+
return C && C->isZero();
12582+
}
12583+
1257212584
HandleSDNode::~HandleSDNode() {
1257312585
DropOperands();
1257412586
}

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -57923,22 +57923,20 @@ static SDValue combineAdd(SDNode *N, SelectionDAG &DAG,
5792357923
}
5792457924
}
5792557925

57926+
SDValue X, Y;
57927+
5792657928
// add(psadbw(X,0),psadbw(Y,0)) -> psadbw(add(X,Y),0)
5792757929
// iff X and Y won't overflow.
57928-
if (Op0.getOpcode() == X86ISD::PSADBW && Op1.getOpcode() == X86ISD::PSADBW &&
57929-
ISD::isBuildVectorAllZeros(Op0.getOperand(1).getNode()) &&
57930-
ISD::isBuildVectorAllZeros(Op1.getOperand(1).getNode())) {
57931-
if (DAG.willNotOverflowAdd(false, Op0.getOperand(0), Op1.getOperand(0))) {
57932-
MVT OpVT = Op0.getOperand(1).getSimpleValueType();
57933-
SDValue Sum =
57934-
DAG.getNode(ISD::ADD, DL, OpVT, Op0.getOperand(0), Op1.getOperand(0));
57935-
return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum,
57936-
getZeroVector(OpVT, Subtarget, DAG, DL));
57937-
}
57930+
if (sd_match(Op0, m_c_BinOp(X86ISD::PSADBW, m_Value(X), m_Zero())) &&
57931+
sd_match(Op1, m_c_BinOp(X86ISD::PSADBW, m_Value(Y), m_Zero())) &&
57932+
DAG.willNotOverflowAdd(/*IsSigned=*/false, X, Y)) {
57933+
MVT OpVT = X.getSimpleValueType();
57934+
SDValue Sum = DAG.getNode(ISD::ADD, DL, OpVT, X, Y);
57935+
return DAG.getNode(X86ISD::PSADBW, DL, VT, Sum,
57936+
getZeroVector(OpVT, Subtarget, DAG, DL));
5793857937
}
5793957938

5794057939
if (VT.isVector()) {
57941-
SDValue X, Y;
5794257940
EVT BoolVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
5794357941
VT.getVectorElementCount());
5794457942

0 commit comments

Comments
 (0)