Skip to content

Commit

Permalink
[DAG] Support saturated truncate
Browse files Browse the repository at this point in the history
`truncate` is `saturated` if no additional conversion is required
between the target and return values. if the target is `saturated`
when attempting to crop from a `vector`, there is an opportunity
to optimize it.

previously, each architecture had an attemping optimization, so there
was redundant code.

this patch implements common logic by adding `ISD::TRUNCATE_[US]SAT`
to indicate saturated truncate.

Fixes llvm#85903
  • Loading branch information
ParkHanbum committed Jul 17, 2024
1 parent b042af3 commit 66b9332
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 1 deletion.
3 changes: 3 additions & 0 deletions llvm/include/llvm/CodeGen/ISDOpcodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,9 @@ enum NodeType {

/// TRUNCATE - Completely drop the high bits.
TRUNCATE,
/// TRUNCATE_[SU]SAT - Truncate for saturated operand
TRUNCATE_SSAT,
TRUNCATE_USAT,

/// [SU]INT_TO_FP - These operators convert integers (whose interpreted sign
/// depends on the first letter) to floating point.
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/Target/TargetSelectionDAG.td
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,8 @@ def sext : SDNode<"ISD::SIGN_EXTEND", SDTIntExtendOp>;
def zext : SDNode<"ISD::ZERO_EXTEND", SDTIntExtendOp>;
def anyext : SDNode<"ISD::ANY_EXTEND" , SDTIntExtendOp>;
def trunc : SDNode<"ISD::TRUNCATE" , SDTIntTruncOp>;
def truncssat : SDNode<"ISD::TRUNCATE_SSAT", SDTIntTruncOp>;
def truncusat : SDNode<"ISD::TRUNCATE_USAT", SDTIntTruncOp>;
def bitconvert : SDNode<"ISD::BITCAST" , SDTUnaryOp>;
def addrspacecast : SDNode<"ISD::ADDRSPACECAST", SDTUnaryOp>;
def freeze : SDNode<"ISD::FREEZE" , SDTFreeze>;
Expand Down
132 changes: 131 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,8 @@ namespace {
SDValue visitSIGN_EXTEND_INREG(SDNode *N);
SDValue visitEXTEND_VECTOR_INREG(SDNode *N);
SDValue visitTRUNCATE(SDNode *N);
SDValue visitTRUNCATE_SSAT(SDNode *N);
SDValue visitTRUNCATE_USAT(SDNode *N);
SDValue visitBITCAST(SDNode *N);
SDValue visitFREEZE(SDNode *N);
SDValue visitBUILD_PAIR(SDNode *N);
Expand Down Expand Up @@ -1907,6 +1909,8 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::ZERO_EXTEND_VECTOR_INREG:
case ISD::ANY_EXTEND_VECTOR_INREG: return visitEXTEND_VECTOR_INREG(N);
case ISD::TRUNCATE: return visitTRUNCATE(N);
case ISD::TRUNCATE_SSAT: return visitTRUNCATE_SSAT(N);
case ISD::TRUNCATE_USAT: return visitTRUNCATE_USAT(N);
case ISD::BITCAST: return visitBITCAST(N);
case ISD::BUILD_PAIR: return visitBUILD_PAIR(N);
case ISD::FADD: return visitFADD(N);
Expand Down Expand Up @@ -13154,7 +13158,8 @@ SDValue DAGCombiner::matchVSelectOpSizesWithSetCC(SDNode *Cast) {
unsigned CastOpcode = Cast->getOpcode();
assert((CastOpcode == ISD::SIGN_EXTEND || CastOpcode == ISD::ZERO_EXTEND ||
CastOpcode == ISD::TRUNCATE || CastOpcode == ISD::FP_EXTEND ||
CastOpcode == ISD::FP_ROUND) &&
CastOpcode == ISD::TRUNCATE_SSAT ||
CastOpcode == ISD::TRUNCATE_USAT || CastOpcode == ISD::FP_ROUND) &&
"Unexpected opcode for vector select narrowing/widening");

// We only do this transform before legal ops because the pattern may be
Expand Down Expand Up @@ -14867,13 +14872,138 @@ SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
return SDValue();
}

SDValue DAGCombiner::visitTRUNCATE_USAT(SDNode *N) {
EVT VT = N->getValueType(0);
SDValue N0 = N->getOperand(0);
SDValue FPInstr = N0.getOpcode() == ISD::SMAX ? N0.getOperand(0) : N0;
if (FPInstr.getOpcode() == ISD::FP_TO_SINT ||
FPInstr.getOpcode() == ISD::FP_TO_UINT) {
EVT FPVT = FPInstr.getOperand(0).getValueType();
if (!DAG.getTargetLoweringInfo().shouldConvertFpToSat(ISD::FP_TO_UINT_SAT,
FPVT, VT))
return SDValue();
SDValue Sat = DAG.getNode(ISD::FP_TO_UINT_SAT, SDLoc(FPInstr), VT,
FPInstr.getOperand(0),
DAG.getValueType(VT.getScalarType()));
return Sat;
}

return SDValue();
}

SDValue DAGCombiner::visitTRUNCATE_SSAT(SDNode *N) { return SDValue(); }

/// Detect patterns of truncation with unsigned saturation:
///
/// 1. (truncate (umin (x, unsigned_max_of_dest_type)) to dest_type).
/// Return the source value x to be truncated or SDValue() if the pattern was
/// not matched.
///
/// 2. (truncate (smin (smax (x, C1), C2)) to dest_type),
/// where C1 >= 0 and C2 is unsigned max of destination type.
///
/// (truncate (smax (smin (x, C2), C1)) to dest_type)
/// where C1 >= 0, C2 is unsigned max of destination type and C1 <= C2.
///
/// These two patterns are equivalent to:
/// (truncate (umin (smax(x, C1), unsigned_max_of_dest_type)) to dest_type)
/// So return the smax(x, C1) value to be truncated or SDValue() if the
/// pattern was not matched.
static SDValue detectUSatPattern(SDValue In, EVT VT, SelectionDAG &DAG,
const SDLoc &DL) {
EVT InVT = In.getValueType();

// Saturation with truncation. We truncate from InVT to VT.
assert(InVT.getScalarSizeInBits() > VT.getScalarSizeInBits() &&
"Unexpected types for truncate operation");

// Match min/max and return limit value as a parameter.
auto MatchMinMax = [](SDValue V, unsigned Opcode, APInt &Limit) -> SDValue {
if (V.getOpcode() == Opcode &&
ISD::isConstantSplatVector(V.getOperand(1).getNode(), Limit))
return V.getOperand(0);
return SDValue();
};

APInt C1, C2;
if (SDValue UMin = MatchMinMax(In, ISD::UMIN, C2))
// C2 should be equal to UINT32_MAX / UINT16_MAX / UINT8_MAX according
// the element size of the destination type.
if (C2.isMask(VT.getScalarSizeInBits()))
return UMin;

if (SDValue SMin = MatchMinMax(In, ISD::SMIN, C2))
if (MatchMinMax(SMin, ISD::SMAX, C1))
if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()))
return SMin;

if (SDValue SMax = MatchMinMax(In, ISD::SMAX, C1))
if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, C2))
if (C1.isNonNegative() && C2.isMask(VT.getScalarSizeInBits()) &&
C2.uge(C1))
return DAG.getNode(ISD::SMAX, DL, InVT, SMin, In.getOperand(1));

return SDValue();
}

/// Detect patterns of truncation with signed saturation:
/// (truncate (smin ((smax (x, signed_min_of_dest_type)),
/// signed_max_of_dest_type)) to dest_type)
/// or:
/// (truncate (smax ((smin (x, signed_max_of_dest_type)),
/// signed_min_of_dest_type)) to dest_type).
/// With MatchPackUS, the smax/smin range is [0, unsigned_max_of_dest_type].
/// Return the source value to be truncated or SDValue() if the pattern was not
/// matched.
static SDValue detectSSatPattern(SDValue In, EVT VT) {
unsigned NumDstBits = VT.getScalarSizeInBits();
unsigned NumSrcBits = In.getScalarValueSizeInBits();
assert(NumSrcBits > NumDstBits && "Unexpected types for truncate operation");

auto MatchMinMax = [](SDValue V, unsigned Opcode,
const APInt &Limit) -> SDValue {
APInt C;
if (V.getOpcode() == Opcode &&
ISD::isConstantSplatVector(V.getOperand(1).getNode(), C) && C == Limit)
return V.getOperand(0);
return SDValue();
};

APInt SignedMax, SignedMin;
SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
if (SDValue SMin = MatchMinMax(In, ISD::SMIN, SignedMax)) {
if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, SignedMin)) {
return SMax;
}
}
if (SDValue SMax = MatchMinMax(In, ISD::SMAX, SignedMin)) {
if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, SignedMax)) {
return SMin;
}
}
return SDValue();
}

SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
EVT SrcVT = N0.getValueType();
bool isLE = DAG.getDataLayout().isLittleEndian();
SDLoc DL(N);

if (!LegalOperations && N->getOpcode() == ISD::TRUNCATE) {
if (TLI.isOperationLegalOrCustom(ISD::TRUNCATE_SSAT, SrcVT)) {
if (SDValue SSatVal = detectSSatPattern(N0, VT))
return DAG.getNode(ISD::TRUNCATE_SSAT, DL, VT, SSatVal);
}

if (TLI.isOperationLegalOrCustom(ISD::TRUNCATE_USAT, SrcVT)) {
if (SDValue USatVal = detectUSatPattern(N0, VT, DAG, DL))
return DAG.getNode(ISD::TRUNCATE_USAT, DL, VT, USatVal);
}
}

// trunc(undef) = undef
if (N0.isUndef())
return DAG.getUNDEF(VT);
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::SIGN_EXTEND_VECTOR_INREG: return "sign_extend_vector_inreg";
case ISD::ZERO_EXTEND_VECTOR_INREG: return "zero_extend_vector_inreg";
case ISD::TRUNCATE: return "truncate";
case ISD::TRUNCATE_SSAT: return "truncate_ssat";
case ISD::TRUNCATE_USAT: return "truncate_usat";
case ISD::FP_ROUND: return "fp_round";
case ISD::STRICT_FP_ROUND: return "strict_fp_round";
case ISD::FP_EXTEND: return "fp_extend";
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/CodeGen/TargetLoweringBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,10 @@ void TargetLoweringBase::initActions() {
// Absolute difference
setOperationAction({ISD::ABDS, ISD::ABDU}, VT, Expand);

// Saturated trunc
setOperationAction(ISD::TRUNCATE_SSAT, VT, Expand);
setOperationAction(ISD::TRUNCATE_USAT, VT, Expand);

// These default to Expand so they will be expanded to CTLZ/CTTZ by default.
setOperationAction({ISD::CTLZ_ZERO_UNDEF, ISD::CTTZ_ZERO_UNDEF}, VT,
Expand);
Expand Down

0 comments on commit 66b9332

Please sign in to comment.