Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 109 additions & 39 deletions llvm/lib/Target/AMDGPU/SIISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,7 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
ISD::SCALAR_TO_VECTOR,
ISD::ZERO_EXTEND,
ISD::SIGN_EXTEND_INREG,
ISD::ANY_EXTEND,
ISD::EXTRACT_VECTOR_ELT,
ISD::INSERT_VECTOR_ELT,
ISD::FCOPYSIGN});
Expand Down Expand Up @@ -13289,6 +13290,20 @@ static uint32_t getPermuteMask(SDValue V) {
return ~0;
}

static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI);

/// Target-specific DAG combine hook for left-shift nodes.
///
/// The combine is attempted only when the combiner has reached at least the
/// AfterLegalizeTypes level and the node's first result is an i32; in that
/// case the work is delegated to matchPERM. Otherwise an empty SDValue is
/// returned to signal that no combine was performed.
SDValue SITargetLowering::performLeftShiftCombine(SDNode *N,
                                                  DAGCombinerInfo &DCI) const {
  const bool ReachedTypeLegalization =
      DCI.getDAGCombineLevel() >= AfterLegalizeTypes;
  const bool ProducesI32 = N->getValueType(0) == MVT::i32;

  // Both preconditions must hold before handing the node to matchPERM.
  if (!ReachedTypeLegalization || !ProducesI32)
    return SDValue();

  return matchPERM(N, DCI);
}

SDValue SITargetLowering::performAndCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
if (DCI.isBeforeLegalize())
Expand Down Expand Up @@ -13874,6 +13889,9 @@ static bool addresses16Bits(int Mask) {
int Low8 = Mask & 0xff;
int Hi8 = (Mask & 0xff00) >> 8;

if (Hi8 == 0x0c || Low8 == 0x0c)
return false;

assert(Low8 < 8 && Hi8 < 8);
// Are the bytes contiguous in the order of increasing addresses.
bool IsConsecutive = (Hi8 - Low8 == 1);
Expand Down Expand Up @@ -13968,58 +13986,70 @@ static SDValue getDWordFromOffset(SelectionDAG &DAG, SDLoc SL, SDValue Src,

static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
SelectionDAG &DAG = DCI.DAG;
assert(!DAG.getDataLayout().isBigEndian());

[[maybe_unused]] EVT VT = N->getValueType(0);
SmallVector<ByteProvider<SDValue>, 8> PermNodes;
SmallVector<ByteProvider<SDValue>, 4> PermNodes;

// VT is known to be MVT::i32, so we need to provide 4 bytes.
assert(VT == MVT::i32);
for (int i = 0; i < 4; i++) {
// Find the ByteProvider that provides the ith byte of the result of N
std::optional<ByteProvider<SDValue>> P =
calculateByteProvider(SDValue(N, 0), i, 0, /*StartingIndex = */ i);
// TODO support constantZero
if (!P || P->isConstantZero())
if (!P)
return SDValue();

PermNodes.push_back(*P);
}
if (PermNodes.size() != 4)
return SDValue();

std::pair<unsigned, unsigned> FirstSrc(0, PermNodes[0].SrcOffset / 4);
std::optional<std::pair<unsigned, unsigned>> SecondSrc;
static auto isSameSrc = [](SDValue SrcA, unsigned DWordA, SDValue SrcB,
unsigned DWordB) {
// If the Src uses a byte from a different DWORD, then it corresponds
// with a different source
return SrcA == SrcB && DWordA == DWordB;
};

SDValue Src0, Src1;
unsigned DWord0, DWord1;
uint64_t PermMask = 0x00000000;
for (size_t i = 0; i < PermNodes.size(); i++) {
auto PermOp = PermNodes[i];
// Since the mask is applied to Src1:Src2, Src1 bytes must be offset
// by sizeof(Src2) = 4
int SrcByteAdjust = 4;
ByteProvider<SDValue> PermOp = PermNodes[i];
if (PermOp.isConstantZero()) {
PermMask |= 0x0c << (i * 8);
continue;
}

// If the Src uses a byte from a different DWORD, then it corresponds
// with a difference source
if (!PermOp.hasSameSrc(PermNodes[FirstSrc.first]) ||
((PermOp.SrcOffset / 4) != FirstSrc.second)) {
if (SecondSrc)
if (!PermOp.hasSameSrc(PermNodes[SecondSrc->first]) ||
((PermOp.SrcOffset / 4) != SecondSrc->second))
return SDValue();
const SDValue SrcI = PermOp.Src.value();
const unsigned DWordI = PermOp.SrcOffset / 4;
const unsigned ByteI = PermOp.SrcOffset % 4;
if (!Src0) {
Src0 = SrcI;
DWord0 = DWordI;
}

// Set the index of the second distinct Src node
SecondSrc = {i, PermNodes[i].SrcOffset / 4};
assert(!(PermNodes[SecondSrc->first].Src->getValueSizeInBits() % 8));
SrcByteAdjust = 0;
if (!isSameSrc(Src0, DWord0, SrcI, DWordI)) {
if (!Src1) {
Src1 = SrcI;
DWord1 = DWordI;
} else if (!isSameSrc(Src1, DWord1, SrcI, DWordI))
return SDValue();
}
assert((PermOp.SrcOffset % 4) + SrcByteAdjust < 8);
assert(!DAG.getDataLayout().isBigEndian());
PermMask |= ((PermOp.SrcOffset % 4) + SrcByteAdjust) << (i * 8);

// Since the mask is applied to Src0:Src1, Src0 bytes must be offset
// by sizeof(Src1) = 4
const int SrcByteAdjust = SrcI == Src0 ? 4 : 0;
assert(ByteI + SrcByteAdjust < 8);
PermMask |= (ByteI + SrcByteAdjust) << (i * 8);
}

SDLoc DL(N);
SDValue Op = *PermNodes[FirstSrc.first].Src;
Op = getDWordFromOffset(DAG, DL, Op, FirstSrc.second);
SDValue Op = Src0;
Op = getDWordFromOffset(DAG, DL, Op, DWord0);
assert(Op.getValueSizeInBits() == 32);

// Check that we are not just extracting the bytes in order from an op
if (!SecondSrc) {
if (!Src1) {
int Low16 = PermMask & 0xffff;
int Hi16 = (PermMask & 0xffff0000) >> 16;

Expand All @@ -14031,12 +14061,12 @@ static SDValue matchPERM(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
return DAG.getBitcast(MVT::getIntegerVT(32), Op);
}

SDValue OtherOp = SecondSrc ? *PermNodes[SecondSrc->first].Src : Op;

if (SecondSrc) {
OtherOp = getDWordFromOffset(DAG, DL, OtherOp, SecondSrc->second);
SDValue OtherOp;
if (Src1) {
OtherOp = getDWordFromOffset(DAG, DL, Src1, DWord1);
assert(OtherOp.getValueSizeInBits() == 32);
}
} else
OtherOp = Op;

if (hasNon16BitAccesses(PermMask, Op, OtherOp)) {

Expand Down Expand Up @@ -14315,10 +14345,11 @@ SDValue SITargetLowering::performXorCombine(SDNode *N,
return SDValue();
}

SDValue SITargetLowering::performZeroExtendCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
SDValue
SITargetLowering::performZeroOrAnyExtendCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
if (!Subtarget->has16BitInsts() ||
DCI.getDAGCombineLevel() < AfterLegalizeDAG)
DCI.getDAGCombineLevel() < AfterLegalizeTypes)
return SDValue();

EVT VT = N->getValueType(0);
Expand All @@ -14329,7 +14360,41 @@ SDValue SITargetLowering::performZeroExtendCombine(SDNode *N,
if (Src.getValueType() != MVT::i16)
return SDValue();

return SDValue();
// TODO: We bail out below if SrcOffset is not in the first dword (>= 4). It's
// possible we're missing out on some combine opportunities, but we'd need to
// weigh the cost of extracting the byte from the upper dwords.

std::optional<ByteProvider<SDValue>> BP0 =
calculateByteProvider(SDValue(N, 0), 0, 0, 0);
if (!BP0.has_value() || 4 <= BP0->SrcOffset)
return SDValue();
SDValue V0 = BP0->Src.value_or(SDValue());

std::optional<ByteProvider<SDValue>> BP1 =
calculateByteProvider(SDValue(N, 0), 1, 0, 1);
if (!BP1.has_value() || 4 <= BP1->SrcOffset)
return SDValue();
SDValue V1 = BP1->Src.value_or(SDValue());

if (!V0 || !V1 || V0 == V1)
return SDValue();

SelectionDAG &DAG = DCI.DAG;
SDLoc DL(N);
uint32_t PermMask = 0x0c0c0c0c;
if (V0) {
V0 = DAG.getBitcastedAnyExtOrTrunc(V0, DL, MVT::i32);
PermMask = (PermMask & ~0xFF) | (BP0->SrcOffset + 4);
}

if (V1) {
V1 = DAG.getBitcastedAnyExtOrTrunc(V1, DL, MVT::i32);
PermMask = (PermMask & ~(0xFF << 8)) | (BP1->SrcOffset << 8);
}

SDValue P = DAG.getNode(AMDGPUISD::PERM, DL, MVT::i32, V0, V1,
DAG.getConstant(PermMask, DL, MVT::i32));
return P;
}

SDValue
Expand Down Expand Up @@ -16997,6 +17062,10 @@ SDValue SITargetLowering::PerformDAGCombine(SDNode *N,
return performMinMaxCombine(N, DCI);
case ISD::FMA:
return performFMACombine(N, DCI);

case ISD::SHL:
return performLeftShiftCombine(N, DCI);

case ISD::AND:
return performAndCombine(N, DCI);
case ISD::OR:
Expand All @@ -17011,8 +17080,9 @@ SDValue SITargetLowering::PerformDAGCombine(SDNode *N,
}
case ISD::XOR:
return performXorCombine(N, DCI);
case ISD::ANY_EXTEND:
case ISD::ZERO_EXTEND:
return performZeroExtendCombine(N, DCI);
return performZeroOrAnyExtendCombine(N, DCI);
case ISD::SIGN_EXTEND_INREG:
return performSignExtendInRegCombine(N, DCI);
case AMDGPUISD::FP_CLASS:
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/AMDGPU/SIISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,11 @@ class SITargetLowering final : public AMDGPUTargetLowering {
unsigned Opc, SDValue LHS,
const ConstantSDNode *CRHS) const;

SDValue performLeftShiftCombine(SDNode *N, DAGCombinerInfo &DCI) const;
SDValue performAndCombine(SDNode *N, DAGCombinerInfo &DCI) const;
SDValue performOrCombine(SDNode *N, DAGCombinerInfo &DCI) const;
SDValue performXorCombine(SDNode *N, DAGCombinerInfo &DCI) const;
SDValue performZeroExtendCombine(SDNode *N, DAGCombinerInfo &DCI) const;
SDValue performZeroOrAnyExtendCombine(SDNode *N, DAGCombinerInfo &DCI) const;
SDValue performSignExtendInRegCombine(SDNode *N, DAGCombinerInfo &DCI) const;
SDValue performClassCombine(SDNode *N, DAGCombinerInfo &DCI) const;
SDValue getCanonicalConstantFP(SelectionDAG &DAG, const SDLoc &SL, EVT VT,
Expand Down