Skip to content

[AArch64][GlobalISel] Perfect Shuffles #106446

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[AArch64][GlobalISel] Perfect Shuffles
This is a port of the existing perfect shuffle generation code from SDAG,
geneticized to work for both SDAG and GISel.  I wrote it a while ago and it has
been sitting on my machine. It brings the codegen for certain shuffles inline
and avoids the need for generating a tbl and constant pool load.
  • Loading branch information
davemgreen committed Apr 22, 2025
commit 1d6f9c35c2afd0f4b201c09d344f4a82a04d86ef
17 changes: 17 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -1365,6 +1365,23 @@ class MachineIRBuilder {
const SrcOp &Elt,
const SrcOp &Idx);

/// Build and insert \p Res = G_INSERT_VECTOR_ELT \p Val, \p Elt, \p Idx
///
/// \pre setBasicBlock or setMI must have been called.
/// \pre \p Res must be a generic virtual register with scalar type.
/// \pre \p Val must be a generic virtual register with vector type.
/// \pre \p Elt must be a generic virtual register with scalar type.
///
/// \return The newly created instruction.
MachineInstrBuilder buildInsertVectorElementConstant(const DstOp &Res,
const SrcOp &Val,
const SrcOp &Elt,
const int Idx) {
const TargetLowering *TLI = getMF().getSubtarget().getTargetLowering();
LLT IdxTy = TLI->getVectorIdxLLT(getDataLayout());
return buildInsertVectorElement(Res, Val, Elt, buildConstant(IdxTy, Idx));
}

/// Build and insert \p Res = G_EXTRACT_VECTOR_ELT \p Val, \p Idx
///
/// \pre setBasicBlock or setMI must have been called.
Expand Down
10 changes: 9 additions & 1 deletion llvm/lib/Target/AArch64/AArch64Combine.td
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,13 @@ def shuf_to_ins: GICombineRule <
(apply [{ applyINS(*${root}, MRI, B, ${matchinfo}); }])
>;

def perfect_shuffle: GICombineRule <
(defs root:$root),
(match (G_SHUFFLE_VECTOR $dst, $src1, $src2, $mask):$root,
[{ return matchPerfectShuffle(*${root}, MRI); }]),
(apply [{ applyPerfectShuffle(*${root}, MRI, B); }])
>;

def vashr_vlshr_imm_matchdata : GIDefMatchData<"int64_t">;
def vashr_vlshr_imm : GICombineRule<
(defs root:$root, vashr_vlshr_imm_matchdata:$matchinfo),
Expand All @@ -173,7 +180,8 @@ def form_duplane : GICombineRule <
>;

def shuffle_vector_lowering : GICombineGroup<[dup, rev, ext, zip, uzp, trn, fullrev,
form_duplane, shuf_to_ins]>;
form_duplane, shuf_to_ins,
perfect_shuffle]>;

// Turn G_UNMERGE_VALUES -> G_EXTRACT_VECTOR_ELT's
def vector_unmerge_lowering : GICombineRule <
Expand Down
257 changes: 89 additions & 168 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13507,172 +13507,6 @@ static SDValue tryFormConcatFromShuffle(SDValue Op, SelectionDAG &DAG) {
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, V0, V1);
}

/// GeneratePerfectShuffle - Given an entry in the perfect-shuffle table, emit
/// the specified operations to build the shuffle. ID is the perfect-shuffle
//ID, V1 and V2 are the original shuffle inputs. PFEntry is the Perfect shuffle
//table entry and LHS/RHS are the immediate inputs for this stage of the
//shuffle.
static SDValue GeneratePerfectShuffle(unsigned ID, SDValue V1,
SDValue V2, unsigned PFEntry, SDValue LHS,
SDValue RHS, SelectionDAG &DAG,
const SDLoc &dl) {
unsigned OpNum = (PFEntry >> 26) & 0x0F;
unsigned LHSID = (PFEntry >> 13) & ((1 << 13) - 1);
unsigned RHSID = (PFEntry >> 0) & ((1 << 13) - 1);

enum {
OP_COPY = 0, // Copy, used for things like <u,u,u,3> to say it is <0,1,2,3>
OP_VREV,
OP_VDUP0,
OP_VDUP1,
OP_VDUP2,
OP_VDUP3,
OP_VEXT1,
OP_VEXT2,
OP_VEXT3,
OP_VUZPL, // VUZP, left result
OP_VUZPR, // VUZP, right result
OP_VZIPL, // VZIP, left result
OP_VZIPR, // VZIP, right result
OP_VTRNL, // VTRN, left result
OP_VTRNR, // VTRN, right result
OP_MOVLANE // Move lane. RHSID is the lane to move into
};

if (OpNum == OP_COPY) {
if (LHSID == (1 * 9 + 2) * 9 + 3)
return LHS;
assert(LHSID == ((4 * 9 + 5) * 9 + 6) * 9 + 7 && "Illegal OP_COPY!");
return RHS;
}

if (OpNum == OP_MOVLANE) {
// Decompose a PerfectShuffle ID to get the Mask for lane Elt
auto getPFIDLane = [](unsigned ID, int Elt) -> int {
assert(Elt < 4 && "Expected Perfect Lanes to be less than 4");
Elt = 3 - Elt;
while (Elt > 0) {
ID /= 9;
Elt--;
}
return (ID % 9 == 8) ? -1 : ID % 9;
};

// For OP_MOVLANE shuffles, the RHSID represents the lane to move into. We
// get the lane to move from the PFID, which is always from the
// original vectors (V1 or V2).
SDValue OpLHS = GeneratePerfectShuffle(
LHSID, V1, V2, PerfectShuffleTable[LHSID], LHS, RHS, DAG, dl);
EVT VT = OpLHS.getValueType();
assert(RHSID < 8 && "Expected a lane index for RHSID!");
unsigned ExtLane = 0;
SDValue Input;

// OP_MOVLANE are either D movs (if bit 0x4 is set) or S movs. D movs
// convert into a higher type.
if (RHSID & 0x4) {
int MaskElt = getPFIDLane(ID, (RHSID & 0x01) << 1) >> 1;
if (MaskElt == -1)
MaskElt = (getPFIDLane(ID, ((RHSID & 0x01) << 1) + 1) - 1) >> 1;
assert(MaskElt >= 0 && "Didn't expect an undef movlane index!");
ExtLane = MaskElt < 2 ? MaskElt : (MaskElt - 2);
Input = MaskElt < 2 ? V1 : V2;
if (VT.getScalarSizeInBits() == 16) {
Input = DAG.getBitcast(MVT::v2f32, Input);
OpLHS = DAG.getBitcast(MVT::v2f32, OpLHS);
} else {
assert(VT.getScalarSizeInBits() == 32 &&
"Expected 16 or 32 bit shuffle elemements");
Input = DAG.getBitcast(MVT::v2f64, Input);
OpLHS = DAG.getBitcast(MVT::v2f64, OpLHS);
}
} else {
int MaskElt = getPFIDLane(ID, RHSID);
assert(MaskElt >= 0 && "Didn't expect an undef movlane index!");
ExtLane = MaskElt < 4 ? MaskElt : (MaskElt - 4);
Input = MaskElt < 4 ? V1 : V2;
// Be careful about creating illegal types. Use f16 instead of i16.
if (VT == MVT::v4i16) {
Input = DAG.getBitcast(MVT::v4f16, Input);
OpLHS = DAG.getBitcast(MVT::v4f16, OpLHS);
}
}
SDValue Ext = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
Input.getValueType().getVectorElementType(),
Input, DAG.getVectorIdxConstant(ExtLane, dl));
SDValue Ins =
DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, Input.getValueType(), OpLHS,
Ext, DAG.getVectorIdxConstant(RHSID & 0x3, dl));
return DAG.getBitcast(VT, Ins);
}

SDValue OpLHS, OpRHS;
OpLHS = GeneratePerfectShuffle(LHSID, V1, V2, PerfectShuffleTable[LHSID], LHS,
RHS, DAG, dl);
OpRHS = GeneratePerfectShuffle(RHSID, V1, V2, PerfectShuffleTable[RHSID], LHS,
RHS, DAG, dl);
EVT VT = OpLHS.getValueType();

switch (OpNum) {
default:
llvm_unreachable("Unknown shuffle opcode!");
case OP_VREV:
// VREV divides the vector in half and swaps within the half.
if (VT.getVectorElementType() == MVT::i32 ||
VT.getVectorElementType() == MVT::f32)
return DAG.getNode(AArch64ISD::REV64, dl, VT, OpLHS);
// vrev <4 x i16> -> REV32
if (VT.getVectorElementType() == MVT::i16 ||
VT.getVectorElementType() == MVT::f16 ||
VT.getVectorElementType() == MVT::bf16)
return DAG.getNode(AArch64ISD::REV32, dl, VT, OpLHS);
// vrev <4 x i8> -> REV16
assert(VT.getVectorElementType() == MVT::i8);
return DAG.getNode(AArch64ISD::REV16, dl, VT, OpLHS);
case OP_VDUP0:
case OP_VDUP1:
case OP_VDUP2:
case OP_VDUP3: {
EVT EltTy = VT.getVectorElementType();
unsigned Opcode;
if (EltTy == MVT::i8)
Opcode = AArch64ISD::DUPLANE8;
else if (EltTy == MVT::i16 || EltTy == MVT::f16 || EltTy == MVT::bf16)
Opcode = AArch64ISD::DUPLANE16;
else if (EltTy == MVT::i32 || EltTy == MVT::f32)
Opcode = AArch64ISD::DUPLANE32;
else if (EltTy == MVT::i64 || EltTy == MVT::f64)
Opcode = AArch64ISD::DUPLANE64;
else
llvm_unreachable("Invalid vector element type?");

if (VT.getSizeInBits() == 64)
OpLHS = WidenVector(OpLHS, DAG);
SDValue Lane = DAG.getConstant(OpNum - OP_VDUP0, dl, MVT::i64);
return DAG.getNode(Opcode, dl, VT, OpLHS, Lane);
}
case OP_VEXT1:
case OP_VEXT2:
case OP_VEXT3: {
unsigned Imm = (OpNum - OP_VEXT1 + 1) * getExtFactor(OpLHS);
return DAG.getNode(AArch64ISD::EXT, dl, VT, OpLHS, OpRHS,
DAG.getConstant(Imm, dl, MVT::i32));
}
case OP_VUZPL:
return DAG.getNode(AArch64ISD::UZP1, dl, VT, OpLHS, OpRHS);
case OP_VUZPR:
return DAG.getNode(AArch64ISD::UZP2, dl, VT, OpLHS, OpRHS);
case OP_VZIPL:
return DAG.getNode(AArch64ISD::ZIP1, dl, VT, OpLHS, OpRHS);
case OP_VZIPR:
return DAG.getNode(AArch64ISD::ZIP2, dl, VT, OpLHS, OpRHS);
case OP_VTRNL:
return DAG.getNode(AArch64ISD::TRN1, dl, VT, OpLHS, OpRHS);
case OP_VTRNR:
return DAG.getNode(AArch64ISD::TRN2, dl, VT, OpLHS, OpRHS);
}
}

static SDValue GenerateTBL(SDValue Op, ArrayRef<int> ShuffleMask,
SelectionDAG &DAG) {
// Check to see if we can use the TBL instruction.
Expand Down Expand Up @@ -14096,8 +13930,95 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
unsigned PFTableIndex = PFIndexes[0] * 9 * 9 * 9 + PFIndexes[1] * 9 * 9 +
PFIndexes[2] * 9 + PFIndexes[3];
unsigned PFEntry = PerfectShuffleTable[PFTableIndex];
return GeneratePerfectShuffle(PFTableIndex, V1, V2, PFEntry, V1, V2, DAG,
dl);

auto BuildRev = [&DAG, &dl](SDValue OpLHS) {
EVT VT = OpLHS.getValueType();
unsigned Opcode = VT.getScalarSizeInBits() == 32 ? AArch64ISD::REV64
: VT.getScalarSizeInBits() == 16 ? AArch64ISD::REV32
: AArch64ISD::REV16;
return DAG.getNode(Opcode, dl, VT, OpLHS);
};
auto BuildDup = [&DAG, &dl](SDValue OpLHS, unsigned Lane) {
EVT VT = OpLHS.getValueType();
unsigned Opcode;
if (VT.getScalarSizeInBits() == 8)
Opcode = AArch64ISD::DUPLANE8;
else if (VT.getScalarSizeInBits() == 16)
Opcode = AArch64ISD::DUPLANE16;
else if (VT.getScalarSizeInBits() == 32)
Opcode = AArch64ISD::DUPLANE32;
else if (VT.getScalarSizeInBits() == 64)
Opcode = AArch64ISD::DUPLANE64;
else
llvm_unreachable("Invalid vector element type?");

if (VT.getSizeInBits() == 64)
OpLHS = WidenVector(OpLHS, DAG);
return DAG.getNode(Opcode, dl, VT, OpLHS,
DAG.getConstant(Lane, dl, MVT::i64));
};
auto BuildExt = [&DAG, &dl](SDValue OpLHS, SDValue OpRHS, unsigned Imm) {
EVT VT = OpLHS.getValueType();
Imm = Imm * getExtFactor(OpLHS);
return DAG.getNode(AArch64ISD::EXT, dl, VT, OpLHS, OpRHS,
DAG.getConstant(Imm, dl, MVT::i32));
};
auto BuildZipLike = [&DAG, &dl](unsigned OpNum, SDValue OpLHS,
SDValue OpRHS) {
EVT VT = OpLHS.getValueType();
switch (OpNum) {
default:
llvm_unreachable("Unexpected perfect shuffle opcode\n");
case OP_VUZPL:
return DAG.getNode(AArch64ISD::UZP1, dl, VT, OpLHS, OpRHS);
case OP_VUZPR:
return DAG.getNode(AArch64ISD::UZP2, dl, VT, OpLHS, OpRHS);
case OP_VZIPL:
return DAG.getNode(AArch64ISD::ZIP1, dl, VT, OpLHS, OpRHS);
case OP_VZIPR:
return DAG.getNode(AArch64ISD::ZIP2, dl, VT, OpLHS, OpRHS);
case OP_VTRNL:
return DAG.getNode(AArch64ISD::TRN1, dl, VT, OpLHS, OpRHS);
case OP_VTRNR:
return DAG.getNode(AArch64ISD::TRN2, dl, VT, OpLHS, OpRHS);
}
};
auto BuildExtractInsert64 = [&DAG, &dl](SDValue ExtSrc, unsigned ExtLane,
SDValue InsSrc, unsigned InsLane) {
EVT VT = InsSrc.getValueType();
if (VT.getScalarSizeInBits() == 16) {
ExtSrc = DAG.getBitcast(MVT::v2f32, ExtSrc);
InsSrc = DAG.getBitcast(MVT::v2f32, InsSrc);
} else if (VT.getScalarSizeInBits() == 32) {
ExtSrc = DAG.getBitcast(MVT::v2f64, ExtSrc);
InsSrc = DAG.getBitcast(MVT::v2f64, InsSrc);
}
SDValue Ext = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
ExtSrc.getValueType().getVectorElementType(),
ExtSrc, DAG.getVectorIdxConstant(ExtLane, dl));
SDValue Ins =
DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, ExtSrc.getValueType(), InsSrc,
Ext, DAG.getVectorIdxConstant(InsLane, dl));
return DAG.getBitcast(VT, Ins);
};
auto BuildExtractInsert32 = [&DAG, &dl](SDValue ExtSrc, unsigned ExtLane,
SDValue InsSrc, unsigned InsLane) {
EVT VT = InsSrc.getValueType();
if (VT.getScalarSizeInBits() == 16) {
ExtSrc = DAG.getBitcast(MVT::v4f16, ExtSrc);
InsSrc = DAG.getBitcast(MVT::v4f16, InsSrc);
}
SDValue Ext = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
ExtSrc.getValueType().getVectorElementType(),
ExtSrc, DAG.getVectorIdxConstant(ExtLane, dl));
SDValue Ins =
DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, ExtSrc.getValueType(), InsSrc,
Ext, DAG.getVectorIdxConstant(InsLane, dl));
return DAG.getBitcast(VT, Ins);
};
return generatePerfectShuffle<SDValue, MVT>(
PFTableIndex, V1, V2, PFEntry, V1, V2, BuildExtractInsert64,
BuildExtractInsert32, BuildRev, BuildDup, BuildExt, BuildZipLike);
}

// Check for a "select shuffle", generating a BSL to pick between lanes in
Expand Down
Loading