@@ -50995,38 +50995,31 @@ static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG,
50995
50995
/// pattern was not matched.
50996
50996
static SDValue detectUSatPattern(SDValue In, EVT VT, SelectionDAG &DAG,
50997
50997
const SDLoc &DL) {
50998
+ using namespace llvm::SDPatternMatch;
50998
50999
EVT InVT = In.getValueType();
50999
51000
51000
51001
// Saturation with truncation. We truncate from InVT to VT.
51001
51002
assert(InVT.getScalarSizeInBits() > VT.getScalarSizeInBits() &&
51002
51003
"Unexpected types for truncate operation");
51003
51004
51004
- // Match min/max and return limit value as a parameter.
51005
- auto MatchMinMax = [](SDValue V, unsigned Opcode, APInt &Limit) -> SDValue {
51006
- if (V.getOpcode() == Opcode &&
51007
- ISD::isConstantSplatVector(V.getOperand(1).getNode(), Limit))
51008
- return V.getOperand(0);
51009
- return SDValue();
51010
- };
51011
-
51012
51005
APInt C1, C2;
51013
- if (SDValue UMin = MatchMinMax(In, ISD::UMIN, C2))
51014
- // C2 should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according
51015
- // the element size of the destination type.
51016
- if (C2.isMask(VT.getScalarSizeInBits()))
51017
- return UMin;
51006
+ SDValue UMin, SMin, SMax;
51018
51007
51019
- if (SDValue SMin = MatchMinMax(In, ISD::SMIN, C2))
51020
- if (MatchMinMax(SMin, ISD::SMAX, C1))
51021
- if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()))
51022
- return SMin;
51008
+ // C2 should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according
51009
+ // the element size of the destination type.
51010
+ if (sd_match(In, m_UMin(m_Value(UMin), m_ConstInt(C2))) &&
51011
+ C2.isMask(VT.getScalarSizeInBits()))
51012
+ return UMin;
51023
51013
51024
- if (SDValue SMax = MatchMinMax(In, ISD::SMAX, C1))
51025
- if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, C2))
51026
- if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()) &&
51027
- C2.uge(C1)) {
51028
- return DAG.getNode(ISD::SMAX, DL, InVT, SMin, In.getOperand(1));
51029
- }
51014
+ if (sd_match(In, m_SMin(m_Value(SMin), m_ConstInt(C2))) &&
51015
+ sd_match(SMin, m_SMax(m_Value(SMax), m_ConstInt(C1))) &&
51016
+ C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()))
51017
+ return SMin;
51018
+
51019
+ if (sd_match(In, m_SMax(m_Value(SMax), m_ConstInt(C1))) &&
51020
+ sd_match(SMax, m_SMin(m_Value(SMin), m_ConstInt(C2))) &&
51021
+ C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()) && C2.uge(C1))
51022
+ return DAG.getNode(ISD::SMAX, DL, InVT, SMin, In.getOperand(1));
51030
51023
51031
51024
return SDValue();
51032
51025
}
@@ -51041,35 +51034,28 @@ static SDValue detectUSatPattern(SDValue In, EVT VT, SelectionDAG &DAG,
51041
51034
/// Return the source value to be truncated or SDValue() if the pattern was not
51042
51035
/// matched.
51043
51036
static SDValue detectSSatPattern(SDValue In, EVT VT, bool MatchPackUS = false) {
51037
+ using namespace llvm::SDPatternMatch;
51044
51038
unsigned NumDstBits = VT.getScalarSizeInBits();
51045
51039
unsigned NumSrcBits = In.getScalarValueSizeInBits();
51046
51040
assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");
51047
51041
51048
- auto MatchMinMax = [](SDValue V, unsigned Opcode,
51049
- const APInt &Limit) -> SDValue {
51050
- APInt C;
51051
- if (V.getOpcode() == Opcode &&
51052
- ISD::isConstantSplatVector(V.getOperand(1).getNode(), C) && C == Limit)
51053
- return V.getOperand(0);
51054
- return SDValue();
51055
- };
51056
-
51057
51042
APInt SignedMax, SignedMin;
51058
51043
if (MatchPackUS) {
51059
51044
SignedMax = APInt::getAllOnes(NumDstBits).zext(NumSrcBits);
51060
- SignedMin = APInt(NumSrcBits, 0 );
51045
+ SignedMin = APInt::getZero (NumSrcBits);
51061
51046
} else {
51062
51047
SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
51063
51048
SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
51064
51049
}
51065
51050
51066
- if (SDValue SMin = MatchMinMax(In, ISD::SMIN, SignedMax))
51067
- if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, SignedMin))
51068
- return SMax;
51051
+ SDValue SMin, SMax;
51052
+ if (sd_match(In, m_SMin(m_Value(SMin), m_SpecificInt(SignedMax))) &&
51053
+ sd_match(SMin, m_SMax(m_Value(SMax), m_SpecificInt(SignedMin))))
51054
+ return SMax;
51069
51055
51070
- if (SDValue SMax = MatchMinMax (In, ISD::SMAX, SignedMin))
51071
- if (SDValue SMin = MatchMinMax (SMax, ISD::SMIN, SignedMax))
51072
- return SMin;
51056
+ if (sd_match (In, m_SMax(m_Value(SMax), m_SpecificInt( SignedMin))) &&
51057
+ sd_match (SMax, m_SMin(m_Value(SMin), m_SpecificInt( SignedMax)) ))
51058
+ return SMin;
51073
51059
51074
51060
return SDValue();
51075
51061
}
0 commit comments