Skip to content

Commit

Permalink
[X86] Convert truncsat clamping patterns to use SDPatternMatch. NFC.
Browse files Browse the repository at this point in the history
Inspired by #99418 (which hopefully we can replace this code with at some point)
  • Loading branch information
RKSimon committed Aug 9, 2024
1 parent f4d5b14 commit 669d844
Showing 1 changed file with 25 additions and 39 deletions.
64 changes: 25 additions & 39 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50995,38 +50995,31 @@ static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG,
/// pattern was not matched.
static SDValue detectUSatPattern(SDValue In, EVT VT, SelectionDAG &DAG,
const SDLoc &DL) {
using namespace llvm::SDPatternMatch;
EVT InVT = In.getValueType();

// Saturation with truncation. We truncate from InVT to VT.
assert(InVT.getScalarSizeInBits() > VT.getScalarSizeInBits() &&
"Unexpected types for truncate operation");

// Match min/max and return limit value as a parameter.
auto MatchMinMax = [](SDValue V, unsigned Opcode, APInt &Limit) -> SDValue {
if (V.getOpcode() == Opcode &&
ISD::isConstantSplatVector(V.getOperand(1).getNode(), Limit))
return V.getOperand(0);
return SDValue();
};

APInt C1, C2;
if (SDValue UMin = MatchMinMax(In, ISD::UMIN, C2))
// C2 should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according
// the element size of the destination type.
if (C2.isMask(VT.getScalarSizeInBits()))
return UMin;
SDValue UMin, SMin, SMax;

if (SDValue SMin = MatchMinMax(In, ISD::SMIN, C2))
if (MatchMinMax(SMin, ISD::SMAX, C1))
if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()))
return SMin;
// C2 should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according
// the element size of the destination type.
if (sd_match(In, m_UMin(m_Value(UMin), m_ConstInt(C2))) &&
C2.isMask(VT.getScalarSizeInBits()))
return UMin;

if (SDValue SMax = MatchMinMax(In, ISD::SMAX, C1))
if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, C2))
if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()) &&
C2.uge(C1)) {
return DAG.getNode(ISD::SMAX, DL, InVT, SMin, In.getOperand(1));
}
if (sd_match(In, m_SMin(m_Value(SMin), m_ConstInt(C2))) &&
sd_match(SMin, m_SMax(m_Value(SMax), m_ConstInt(C1))) &&
C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()))
return SMin;

if (sd_match(In, m_SMax(m_Value(SMax), m_ConstInt(C1))) &&
sd_match(SMax, m_SMin(m_Value(SMin), m_ConstInt(C2))) &&
C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()) && C2.uge(C1))
return DAG.getNode(ISD::SMAX, DL, InVT, SMin, In.getOperand(1));

return SDValue();
}
Expand All @@ -51041,35 +51034,28 @@ static SDValue detectUSatPattern(SDValue In, EVT VT, SelectionDAG &DAG,
/// Return the source value to be truncated or SDValue() if the pattern was not
/// matched.
static SDValue detectSSatPattern(SDValue In, EVT VT, bool MatchPackUS = false) {
using namespace llvm::SDPatternMatch;
unsigned NumDstBits = VT.getScalarSizeInBits();
unsigned NumSrcBits = In.getScalarValueSizeInBits();
assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");

auto MatchMinMax = [](SDValue V, unsigned Opcode,
const APInt &Limit) -> SDValue {
APInt C;
if (V.getOpcode() == Opcode &&
ISD::isConstantSplatVector(V.getOperand(1).getNode(), C) && C == Limit)
return V.getOperand(0);
return SDValue();
};

APInt SignedMax, SignedMin;
if (MatchPackUS) {
SignedMax = APInt::getAllOnes(NumDstBits).zext(NumSrcBits);
SignedMin = APInt(NumSrcBits, 0);
SignedMin = APInt::getZero(NumSrcBits);
} else {
SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
}

if (SDValue SMin = MatchMinMax(In, ISD::SMIN, SignedMax))
if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, SignedMin))
return SMax;
SDValue SMin, SMax;
if (sd_match(In, m_SMin(m_Value(SMin), m_SpecificInt(SignedMax))) &&
sd_match(SMin, m_SMax(m_Value(SMax), m_SpecificInt(SignedMin))))
return SMax;

if (SDValue SMax = MatchMinMax(In, ISD::SMAX, SignedMin))
if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, SignedMax))
return SMin;
if (sd_match(In, m_SMax(m_Value(SMax), m_SpecificInt(SignedMin))) &&
sd_match(SMax, m_SMin(m_Value(SMin), m_SpecificInt(SignedMax))))
return SMin;

return SDValue();
}
Expand Down

0 comments on commit 669d844

Please sign in to comment.