Skip to content

Commit

Permalink
[ISel] Port AArch64 SABD and UABD to DAGCombine
Browse files Browse the repository at this point in the history
This ports the AArch64 SABD and UABD over to DAG Combine, where they can be
used by more backends (notably MVE in a follow-up patch). The matching code
has changed very little, just to handle legal operations and types
differently. It selects from (ABS (SUB (EXTEND a), (EXTEND b))), producing
an abds/abdu which is zero-extended to the original type.

Differential Revision: https://reviews.llvm.org/D91937
  • Loading branch information
davemgreen committed Jun 26, 2021
1 parent 8c2d462 commit 2887f14
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 63 deletions.
7 changes: 7 additions & 0 deletions llvm/include/llvm/CodeGen/ISDOpcodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,13 @@ enum NodeType {
MULHU,
MULHS,

/// ABDS/ABDU - Absolute difference - Return the absolute difference between
/// two numbers interpreted as signed/unsigned.
/// i.e. trunc(abs(sext(Op0) - sext(Op1))) becomes abds(Op0, Op1)
///   or trunc(abs(zext(Op0) - zext(Op1))) becomes abdu(Op0, Op1)
ABDS,
ABDU,

/// [US]{MIN/MAX} - Binary minimum or maximum of signed or unsigned
/// integers.
SMIN,
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/Target/TargetSelectionDAG.td
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,8 @@ def mul : SDNode<"ISD::MUL" , SDTIntBinOp,
[SDNPCommutative, SDNPAssociative]>;
def mulhs : SDNode<"ISD::MULHS" , SDTIntBinOp, [SDNPCommutative]>;
def mulhu : SDNode<"ISD::MULHU" , SDTIntBinOp, [SDNPCommutative]>;
// Absolute difference of two integers, with the operands interpreted as
// signed (abds -> ISD::ABDS) or unsigned (abdu -> ISD::ABDU). Both are
// commutative integer binary operations.
def abds : SDNode<"ISD::ABDS" , SDTIntBinOp, [SDNPCommutative]>;
def abdu : SDNode<"ISD::ABDU" , SDTIntBinOp, [SDNPCommutative]>;
def smullohi : SDNode<"ISD::SMUL_LOHI" , SDTIntBinHiLoOp, [SDNPCommutative]>;
def umullohi : SDNode<"ISD::UMUL_LOHI" , SDTIntBinHiLoOp, [SDNPCommutative]>;
def sdiv : SDNode<"ISD::SDIV" , SDTIntBinOp>;
Expand Down
38 changes: 38 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9071,6 +9071,40 @@ SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
return SDValue();
}

// Given an ABS node, detect the following pattern:
//   (ABS (SUB (EXTEND a), (EXTEND b)))
// where both EXTENDs are the same opcode (both ZERO_EXTEND or both
// SIGN_EXTEND) and a and b have the same type. If ISD::ABDU (for zero
// extends) or ISD::ABDS (for sign extends) is legal or custom for that
// narrow type, the pattern is rewritten to:
//   (ZERO_EXTEND (ABD[US] a, b))
// Returns an empty SDValue when the pattern does not match.
static SDValue combineABSToABD(SDNode *N, SelectionDAG &DAG,
                               const TargetLowering &TLI) {
  SDValue AbsOp1 = N->getOperand(0);
  if (AbsOp1.getOpcode() != ISD::SUB)
    return SDValue();

  SDValue Op0 = AbsOp1.getOperand(0);
  SDValue Op1 = AbsOp1.getOperand(1);

  unsigned Opc0 = Op0.getOpcode();
  // Check if the operands of the sub are (zero|sign)-extended, and that both
  // use the same kind of extension.
  if (Opc0 != Op1.getOpcode() ||
      (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND))
    return SDValue();

  // Strip the extends; from here on Op0/Op1 are the narrow source values.
  Op0 = Op0.getOperand(0);
  Op1 = Op1.getOperand(0);

  EVT VT1 = Op0.getValueType();
  EVT VT2 = Op1.getValueType();
  // Check if the operands are of same type and the ABD node is supported for
  // that type.
  unsigned ABDOpcode = (Opc0 == ISD::SIGN_EXTEND) ? ISD::ABDS : ISD::ABDU;
  if (VT1 != VT2 || !TLI.isOperationLegalOrCustom(ABDOpcode, VT1))
    return SDValue();

  // Build the node with VT1, the type of the operand value itself and the
  // type that was checked for legality above. Op0->getValueType(0) would be
  // the type of result #0 of the defining node, which is a different type
  // when the operand is a non-zero result of a multi-result node.
  SDValue ABD = DAG.getNode(ABDOpcode, SDLoc(N), VT1, Op0, Op1);
  // The absolute difference always fits in the narrow type as an unsigned
  // value, so the widening back to the ABS type is a zero extend even for
  // the signed variant.
  return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD);
}

SDValue DAGCombiner::visitABS(SDNode *N) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
Expand All @@ -9084,6 +9118,10 @@ SDValue DAGCombiner::visitABS(SDNode *N) {
// fold (abs x) -> x iff not-negative
if (DAG.SignBitIsZero(N0))
return N0;

if (SDValue ABD = combineABSToABD(N, DAG, TLI))
return ABD;

return SDValue();
}

Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::MUL: return "mul";
case ISD::MULHU: return "mulhu";
case ISD::MULHS: return "mulhs";
case ISD::ABDS: return "abds";
case ISD::ABDU: return "abdu";
case ISD::SDIV: return "sdiv";
case ISD::UDIV: return "udiv";
case ISD::SREM: return "srem";
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/CodeGen/TargetLoweringBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,10 @@ void TargetLoweringBase::initActions() {
setOperationAction(ISD::SUBC, VT, Expand);
setOperationAction(ISD::SUBE, VT, Expand);

// Absolute difference
setOperationAction(ISD::ABDS, VT, Expand);
setOperationAction(ISD::ABDU, VT, Expand);

// These default to Expand so they will be expanded to CTLZ/CTTZ by default.
setOperationAction(ISD::CTLZ_ZERO_UNDEF, VT, Expand);
setOperationAction(ISD::CTTZ_ZERO_UNDEF, VT, Expand);
Expand Down
68 changes: 14 additions & 54 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1050,6 +1050,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::USUBSAT, VT, Legal);
}

for (MVT VT : {MVT::v8i8, MVT::v4i16, MVT::v2i32, MVT::v16i8, MVT::v8i16,
MVT::v4i32}) {
setOperationAction(ISD::ABDS, VT, Legal);
setOperationAction(ISD::ABDU, VT, Legal);
}

// Vector reductions
for (MVT VT : { MVT::v4f16, MVT::v2f32,
MVT::v8f16, MVT::v4f32, MVT::v2f64 }) {
Expand Down Expand Up @@ -2116,8 +2122,6 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::CTPOP_MERGE_PASSTHRU)
MAKE_CASE(AArch64ISD::DUP_MERGE_PASSTHRU)
MAKE_CASE(AArch64ISD::INDEX_VECTOR)
MAKE_CASE(AArch64ISD::UABD)
MAKE_CASE(AArch64ISD::SABD)
MAKE_CASE(AArch64ISD::UADDLP)
MAKE_CASE(AArch64ISD::CALL_RVMARKER)
}
Expand Down Expand Up @@ -4082,8 +4086,8 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
}
case Intrinsic::aarch64_neon_sabd:
case Intrinsic::aarch64_neon_uabd: {
unsigned Opcode = IntNo == Intrinsic::aarch64_neon_uabd ? AArch64ISD::UABD
: AArch64ISD::SABD;
unsigned Opcode = IntNo == Intrinsic::aarch64_neon_uabd ? ISD::ABDU
: ISD::ABDS;
return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
Op.getOperand(2));
}
Expand Down Expand Up @@ -12099,8 +12103,8 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
SDValue UABDHigh8Op1 =
DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0),
DAG.getConstant(8, DL, MVT::i64));
SDValue UABDHigh8 = DAG.getNode(IsZExt ? AArch64ISD::UABD : AArch64ISD::SABD,
DL, MVT::v8i8, UABDHigh8Op0, UABDHigh8Op1);
SDValue UABDHigh8 = DAG.getNode(IsZExt ? ISD::ABDU : ISD::ABDS, DL, MVT::v8i8,
UABDHigh8Op0, UABDHigh8Op1);
SDValue UABDL = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDHigh8);

// Second, create the node pattern of UABAL.
Expand All @@ -12110,8 +12114,8 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
SDValue UABDLo8Op1 =
DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0),
DAG.getConstant(0, DL, MVT::i64));
SDValue UABDLo8 = DAG.getNode(IsZExt ? AArch64ISD::UABD : AArch64ISD::SABD,
DL, MVT::v8i8, UABDLo8Op0, UABDLo8Op1);
SDValue UABDLo8 = DAG.getNode(IsZExt ? ISD::ABDU : ISD::ABDS, DL, MVT::v8i8,
UABDLo8Op0, UABDLo8Op1);
SDValue ZExtUABD = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDLo8);
SDValue UABAL = DAG.getNode(ISD::ADD, DL, MVT::v8i16, UABDL, ZExtUABD);

Expand Down Expand Up @@ -12170,48 +12174,6 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
}

// Given an ABS node, detect the following pattern:
//   (ABS (SUB (EXTEND a), (EXTEND b)))
// where both EXTENDs are the same opcode (both ZERO_EXTEND or both
// SIGN_EXTEND) and a and b are fixed-length vectors of the same 64-bit or
// 128-bit type with i8, i16 or i32 elements. The pattern is rewritten to:
//   (ZERO_EXTEND (AArch64ISD::UABD/SABD a, b))
// Returns an empty SDValue when the pattern does not match.
// DCI and Subtarget are not used by this combine.
static SDValue performABSCombine(SDNode *N, SelectionDAG &DAG,
                                 TargetLowering::DAGCombinerInfo &DCI,
                                 const AArch64Subtarget *Subtarget) {
  SDValue AbsOp1 = N->getOperand(0);
  if (AbsOp1.getOpcode() != ISD::SUB)
    return SDValue();

  SDValue Op0 = AbsOp1.getOperand(0);
  SDValue Op1 = AbsOp1.getOperand(1);

  unsigned Opc0 = Op0.getOpcode();
  // Check if the operands of the sub are (zero|sign)-extended.
  if (Opc0 != Op1.getOpcode() ||
      (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND))
    return SDValue();

  // Strip the extends; from here on Op0/Op1 are the narrow source values.
  Op0 = Op0.getOperand(0);
  Op1 = Op1.getOperand(0);

  EVT VectorT1 = Op0.getValueType();
  EVT VectorT2 = Op1.getValueType();
  // Reject scalars and scalable vectors before querying the size or element
  // type: a 64/128-bit scalar would otherwise pass the size check and reach
  // getVectorElementType(), and getFixedSizeInBits() is only meaningful for
  // fixed-size types.
  if (VectorT1 != VectorT2 || !VectorT1.isFixedLengthVector())
    return SDValue();

  // Check if vectors are of valid size.
  uint64_t Size = VectorT1.getFixedSizeInBits();
  if (Size != 64 && Size != 128)
    return SDValue();

  // Check if vector element types are valid.
  EVT VT1 = VectorT1.getVectorElementType();
  if (VT1 != MVT::i8 && VT1 != MVT::i16 && VT1 != MVT::i32)
    return SDValue();

  unsigned ABDOpcode =
      (Opc0 == ISD::SIGN_EXTEND) ? AArch64ISD::SABD : AArch64ISD::UABD;
  // Build the node with the type of the operand value itself (VectorT1),
  // which is the type validated above, rather than the type of result #0 of
  // the node that defines it.
  SDValue ABD = DAG.getNode(ABDOpcode, SDLoc(N), VectorT1, Op0, Op1);
  return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD);
}

static SDValue performXorCombine(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
Expand Down Expand Up @@ -14377,8 +14339,8 @@ static SDValue performExtendCombine(SDNode *N,
// helps the backend to decide that an sabdl2 would be useful, saving a real
// extract_high operation.
if (!DCI.isBeforeLegalizeOps() && N->getOpcode() == ISD::ZERO_EXTEND &&
(N->getOperand(0).getOpcode() == AArch64ISD::UABD ||
N->getOperand(0).getOpcode() == AArch64ISD::SABD)) {
(N->getOperand(0).getOpcode() == ISD::ABDU ||
N->getOperand(0).getOpcode() == ISD::ABDS)) {
SDNode *ABDNode = N->getOperand(0).getNode();
SDValue NewABD =
tryCombineLongOpWithDup(Intrinsic::not_intrinsic, ABDNode, DCI, DAG);
Expand Down Expand Up @@ -16344,8 +16306,6 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
default:
LLVM_DEBUG(dbgs() << "Custom combining: skipping\n");
break;
case ISD::ABS:
return performABSCombine(N, DAG, DCI, Subtarget);
case ISD::ADD:
case ISD::SUB:
return performAddSubCombine(N, DCI, DAG);
Expand Down
4 changes: 0 additions & 4 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,10 +236,6 @@ enum NodeType : unsigned {
SRHADD,
URHADD,

// Absolute difference
UABD,
SABD,

// Unsigned Add Long Pairwise
UADDLP,

Expand Down
7 changes: 2 additions & 5 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -579,14 +579,11 @@ def AArch64urhadd : SDNode<"AArch64ISD::URHADD", SDT_AArch64binvec>;
def AArch64shadd : SDNode<"AArch64ISD::SHADD", SDT_AArch64binvec>;
def AArch64uhadd : SDNode<"AArch64ISD::UHADD", SDT_AArch64binvec>;

def AArch64uabd_n : SDNode<"AArch64ISD::UABD", SDT_AArch64binvec>;
def AArch64sabd_n : SDNode<"AArch64ISD::SABD", SDT_AArch64binvec>;

def AArch64uabd : PatFrags<(ops node:$lhs, node:$rhs),
[(AArch64uabd_n node:$lhs, node:$rhs),
[(abdu node:$lhs, node:$rhs),
(int_aarch64_neon_uabd node:$lhs, node:$rhs)]>;
def AArch64sabd : PatFrags<(ops node:$lhs, node:$rhs),
[(AArch64sabd_n node:$lhs, node:$rhs),
[(abds node:$lhs, node:$rhs),
(int_aarch64_neon_sabd node:$lhs, node:$rhs)]>;

def AArch64uaddlp_n : SDNode<"AArch64ISD::UADDLP", SDT_AArch64uaddlp>;
Expand Down

0 comments on commit 2887f14

Please sign in to comment.