Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DAGCombiner] Attempt to fold 'add' nodes to funnel-shift or rotate #125612

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 58 additions & 58 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -649,14 +649,15 @@ namespace {
bool DemandHighBits = true);
SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
SDValue MatchRotatePosNeg(SDValue Shifted, SDValue Pos, SDValue Neg,
SDValue InnerPos, SDValue InnerNeg, bool HasPos,
unsigned PosOpcode, unsigned NegOpcode,
const SDLoc &DL);
SDValue InnerPos, SDValue InnerNeg, bool FromAdd,
bool HasPos, unsigned PosOpcode,
unsigned NegOpcode, const SDLoc &DL);
SDValue MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos, SDValue Neg,
SDValue InnerPos, SDValue InnerNeg, bool HasPos,
unsigned PosOpcode, unsigned NegOpcode,
const SDLoc &DL);
SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
SDValue InnerPos, SDValue InnerNeg, bool FromAdd,
bool HasPos, unsigned PosOpcode,
unsigned NegOpcode, const SDLoc &DL);
SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL,
bool FromAdd);
SDValue MatchLoadCombine(SDNode *N);
SDValue mergeTruncStores(StoreSDNode *N);
SDValue reduceLoadWidth(SDNode *N);
Expand Down Expand Up @@ -2986,6 +2987,9 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
if (SDValue V = foldAddSubOfSignBit(N, DL, DAG))
return V;

if (SDValue V = MatchRotate(N0, N1, SDLoc(N), /*FromAdd=*/true))
return V;

// Try to match AVGFLOOR fixedwidth pattern
if (SDValue V = foldAddToAvg(N, DL))
return V;
Expand Down Expand Up @@ -8175,7 +8179,7 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
return V;

// See if this is some rotate idiom.
if (SDValue Rot = MatchRotate(N0, N1, DL))
if (SDValue Rot = MatchRotate(N0, N1, DL, /*FromAdd=*/false))
return Rot;

if (SDValue Load = MatchLoadCombine(N))
Expand Down Expand Up @@ -8364,7 +8368,7 @@ static SDValue extractShiftForRotate(SelectionDAG &DAG, SDValue OppShift,
// The IsRotate flag should be set when the LHS of both shifts is the same.
// Otherwise if matching a general funnel shift, it should be clear.
static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
SelectionDAG &DAG, bool IsRotate) {
SelectionDAG &DAG, bool IsRotate, bool FromAdd) {
const auto &TLI = DAG.getTargetLoweringInfo();
// If EltSize is a power of 2 then:
//
Expand Down Expand Up @@ -8403,7 +8407,7 @@ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
// NOTE: We can only do this when matching operations which won't modify the
// least Log2(EltSize) significant bits and not a general funnel shift.
unsigned MaskLoBits = 0;
if (IsRotate && isPowerOf2_64(EltSize)) {
if (IsRotate && !FromAdd && isPowerOf2_64(EltSize)) {
unsigned Bits = Log2_64(EltSize);
unsigned NegBits = Neg.getScalarValueSizeInBits();
if (NegBits >= Bits) {
Expand Down Expand Up @@ -8486,22 +8490,21 @@ static bool matchRotateSub(SDValue Pos, SDValue Neg, unsigned EltSize,
// Neg with outer conversions stripped away.
SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
SDValue Neg, SDValue InnerPos,
SDValue InnerNeg, bool HasPos,
unsigned PosOpcode, unsigned NegOpcode,
const SDLoc &DL) {
// fold (or (shl x, (*ext y)),
// (srl x, (*ext (sub 32, y)))) ->
SDValue InnerNeg, bool FromAdd,
bool HasPos, unsigned PosOpcode,
unsigned NegOpcode, const SDLoc &DL) {
// fold (or/add (shl x, (*ext y)),
// (srl x, (*ext (sub 32, y)))) ->
// (rotl x, y) or (rotr x, (sub 32, y))
//
// fold (or (shl x, (*ext (sub 32, y))),
// (srl x, (*ext y))) ->
// fold (or/add (shl x, (*ext (sub 32, y))),
// (srl x, (*ext y))) ->
// (rotr x, y) or (rotl x, (sub 32, y))
EVT VT = Shifted.getValueType();
if (matchRotateSub(InnerPos, InnerNeg, VT.getScalarSizeInBits(), DAG,
/*IsRotate*/ true)) {
/*IsRotate*/ true, FromAdd))
return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, Shifted,
HasPos ? Pos : Neg);
}

return SDValue();
}
Expand All @@ -8514,30 +8517,30 @@ SDValue DAGCombiner::MatchRotatePosNeg(SDValue Shifted, SDValue Pos,
// TODO: Merge with MatchRotatePosNeg.
SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
SDValue Neg, SDValue InnerPos,
SDValue InnerNeg, bool HasPos,
unsigned PosOpcode, unsigned NegOpcode,
const SDLoc &DL) {
SDValue InnerNeg, bool FromAdd,
bool HasPos, unsigned PosOpcode,
unsigned NegOpcode, const SDLoc &DL) {
EVT VT = N0.getValueType();
unsigned EltBits = VT.getScalarSizeInBits();

// fold (or (shl x0, (*ext y)),
// (srl x1, (*ext (sub 32, y)))) ->
// fold (or/add (shl x0, (*ext y)),
// (srl x1, (*ext (sub 32, y)))) ->
// (fshl x0, x1, y) or (fshr x0, x1, (sub 32, y))
//
// fold (or (shl x0, (*ext (sub 32, y))),
// (srl x1, (*ext y))) ->
// fold (or/add (shl x0, (*ext (sub 32, y))),
// (srl x1, (*ext y))) ->
// (fshr x0, x1, y) or (fshl x0, x1, (sub 32, y))
if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG, /*IsRotate*/ N0 == N1)) {
if (matchRotateSub(InnerPos, InnerNeg, EltBits, DAG, /*IsRotate*/ N0 == N1,
FromAdd))
return DAG.getNode(HasPos ? PosOpcode : NegOpcode, DL, VT, N0, N1,
HasPos ? Pos : Neg);
}

// Matching the shift+xor cases, we can't easily use the xor'd shift amount
// so for now just use the PosOpcode case if its legal.
// TODO: When can we use the NegOpcode case?
if (PosOpcode == ISD::FSHL && isPowerOf2_32(EltBits)) {
SDValue X;
// fold (or (shl x0, y), (srl (srl x1, 1), (xor y, 31)))
// fold (or/add (shl x0, y), (srl (srl x1, 1), (xor y, 31)))
// -> (fshl x0, x1, y)
if (sd_match(N1, m_Srl(m_Value(X), m_One())) &&
sd_match(InnerNeg,
Expand All @@ -8546,7 +8549,7 @@ SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
return DAG.getNode(ISD::FSHL, DL, VT, N0, X, Pos);
}

// fold (or (shl (shl x0, 1), (xor y, 31)), (srl x1, y))
// fold (or/add (shl (shl x0, 1), (xor y, 31)), (srl x1, y))
// -> (fshr x0, x1, y)
if (sd_match(N0, m_Shl(m_Value(X), m_One())) &&
sd_match(InnerPos,
Expand All @@ -8555,7 +8558,7 @@ SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
return DAG.getNode(ISD::FSHR, DL, VT, X, N1, Neg);
}

// fold (or (shl (add x0, x0), (xor y, 31)), (srl x1, y))
// fold (or/add (shl (add x0, x0), (xor y, 31)), (srl x1, y))
// -> (fshr x0, x1, y)
// TODO: Should add(x,x) -> shl(x,1) be a general DAG canonicalization?
if (sd_match(N0, m_Add(m_Value(X), m_Deferred(X))) &&
Expand All @@ -8569,11 +8572,12 @@ SDValue DAGCombiner::MatchFunnelPosNeg(SDValue N0, SDValue N1, SDValue Pos,
return SDValue();
}

// MatchRotate - Handle an 'or' of two operands. If this is one of the many
// idioms for rotate, and if the target supports rotation instructions, generate
// a rot[lr]. This also matches funnel shift patterns, similar to rotation but
// with different shifted sources.
SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
// MatchRotate - Handle an 'or' or 'add' of two operands. If this is one of the
// many idioms for rotate, and if the target supports rotation instructions,
// generate a rot[lr]. This also matches funnel shift patterns, similar to
// rotation but with different shifted sources.
SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL,
bool FromAdd) {
EVT VT = LHS.getValueType();

// The target must have at least one rotate/funnel flavor.
Expand All @@ -8600,9 +8604,9 @@ SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
if (LHS.getOpcode() == ISD::TRUNCATE && RHS.getOpcode() == ISD::TRUNCATE &&
LHS.getOperand(0).getValueType() == RHS.getOperand(0).getValueType()) {
assert(LHS.getValueType() == RHS.getValueType());
if (SDValue Rot = MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL)) {
if (SDValue Rot =
MatchRotate(LHS.getOperand(0), RHS.getOperand(0), DL, FromAdd))
return DAG.getNode(ISD::TRUNCATE, SDLoc(LHS), LHS.getValueType(), Rot);
}
}

// Match "(X shl/srl V1) & V2" where V2 may not be present.
Expand Down Expand Up @@ -8736,10 +8740,10 @@ SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
return SDValue(); // Requires funnel shift support.
}

// fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
// fold (or (shl x, C1), (srl x, C2)) -> (rotr x, C2)
// fold (or (shl x, C1), (srl y, C2)) -> (fshl x, y, C1)
// fold (or (shl x, C1), (srl y, C2)) -> (fshr x, y, C2)
// fold (or/add (shl x, C1), (srl x, C2)) -> (rotl x, C1)
// fold (or/add (shl x, C1), (srl x, C2)) -> (rotr x, C2)
// fold (or/add (shl x, C1), (srl y, C2)) -> (fshl x, y, C1)
// fold (or/add (shl x, C1), (srl y, C2)) -> (fshr x, y, C2)
// iff C1+C2 == EltSizeInBits
if (ISD::matchBinaryPredicate(LHSShiftAmt, RHSShiftAmt, MatchRotateSum)) {
SDValue Res;
Expand Down Expand Up @@ -8782,29 +8786,25 @@ SDValue DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL) {
}

if (IsRotate && (HasROTL || HasROTR)) {
SDValue TryL =
MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt, LExtOp0,
RExtOp0, HasROTL, ISD::ROTL, ISD::ROTR, DL);
if (TryL)
if (SDValue TryL = MatchRotatePosNeg(LHSShiftArg, LHSShiftAmt, RHSShiftAmt,
LExtOp0, RExtOp0, FromAdd, HasROTL,
ISD::ROTL, ISD::ROTR, DL))
return TryL;

SDValue TryR =
MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt, RExtOp0,
LExtOp0, HasROTR, ISD::ROTR, ISD::ROTL, DL);
if (TryR)
if (SDValue TryR = MatchRotatePosNeg(RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
RExtOp0, LExtOp0, FromAdd, HasROTR,
ISD::ROTR, ISD::ROTL, DL))
return TryR;
}

SDValue TryL =
MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, LHSShiftAmt, RHSShiftAmt,
LExtOp0, RExtOp0, HasFSHL, ISD::FSHL, ISD::FSHR, DL);
if (TryL)
if (SDValue TryL = MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, LHSShiftAmt,
RHSShiftAmt, LExtOp0, RExtOp0, FromAdd,
HasFSHL, ISD::FSHL, ISD::FSHR, DL))
return TryL;

SDValue TryR =
MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, RHSShiftAmt, LHSShiftAmt,
RExtOp0, LExtOp0, HasFSHR, ISD::FSHR, ISD::FSHL, DL);
if (TryR)
if (SDValue TryR = MatchFunnelPosNeg(LHSShiftArg, RHSShiftArg, RHSShiftAmt,
LHSShiftAmt, RExtOp0, LExtOp0, FromAdd,
HasFSHR, ISD::FSHR, ISD::FSHL, DL))
return TryR;

return SDValue();
Expand Down
Loading