Skip to content

Commit 1aac660

Browse files
zhaoqi5mahesh-attarde
authored andcommitted
[LoongArch] Custom legalize vector_shuffle to xvinsve0.{w/d} when possible (llvm#161156)
1 parent 8f1308b commit 1aac660

File tree

4 files changed

+107
-202
lines changed

4 files changed

+107
-202
lines changed

llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2319,6 +2319,53 @@ static SDValue lowerVECTOR_SHUFFLE_XVPICKOD(const SDLoc &DL, ArrayRef<int> Mask,
23192319
return DAG.getNode(LoongArchISD::VPICKOD, DL, VT, V2, V1);
23202320
}
23212321

2322+
/// Lower VECTOR_SHUFFLE into XVINSVE0 (if possible).
2323+
static SDValue
2324+
lowerVECTOR_SHUFFLE_XVINSVE0(const SDLoc &DL, ArrayRef<int> Mask, MVT VT,
2325+
SDValue V1, SDValue V2, SelectionDAG &DAG,
2326+
const LoongArchSubtarget &Subtarget) {
2327+
// LoongArch LASX only supports xvinsve0.{w/d}.
2328+
if (VT != MVT::v8i32 && VT != MVT::v8f32 && VT != MVT::v4i64 &&
2329+
VT != MVT::v4f64)
2330+
return SDValue();
2331+
2332+
MVT GRLenVT = Subtarget.getGRLenVT();
2333+
int MaskSize = Mask.size();
2334+
assert(MaskSize == (int)VT.getVectorNumElements() && "Unexpected mask size");
2335+
2336+
// Check if exactly one element of the Mask is replaced by 'Replaced', while
2337+
// all other elements are either 'Base + i' or undef (-1). On success, return
2338+
// the index of the replaced element. Otherwise, just return -1.
2339+
auto checkReplaceOne = [&](int Base, int Replaced) -> int {
2340+
int Idx = -1;
2341+
for (int i = 0; i < MaskSize; ++i) {
2342+
if (Mask[i] == Base + i || Mask[i] == -1)
2343+
continue;
2344+
if (Mask[i] != Replaced)
2345+
return -1;
2346+
if (Idx == -1)
2347+
Idx = i;
2348+
else
2349+
return -1;
2350+
}
2351+
return Idx;
2352+
};
2353+
2354+
// Case 1: the lowest element of V2 replaces one element in V1.
2355+
int Idx = checkReplaceOne(0, MaskSize);
2356+
if (Idx != -1)
2357+
return DAG.getNode(LoongArchISD::XVINSVE0, DL, VT, V1, V2,
2358+
DAG.getConstant(Idx, DL, GRLenVT));
2359+
2360+
// Case 2: the lowest element of V1 replaces one element in V2.
2361+
Idx = checkReplaceOne(MaskSize, 0);
2362+
if (Idx != -1)
2363+
return DAG.getNode(LoongArchISD::XVINSVE0, DL, VT, V2, V1,
2364+
DAG.getConstant(Idx, DL, GRLenVT));
2365+
2366+
return SDValue();
2367+
}
2368+
23222369
/// Lower VECTOR_SHUFFLE into XVSHUF (if possible).
23232370
static SDValue lowerVECTOR_SHUFFLE_XVSHUF(const SDLoc &DL, ArrayRef<int> Mask,
23242371
MVT VT, SDValue V1, SDValue V2,
@@ -2595,6 +2642,9 @@ static SDValue lower256BitShuffle(const SDLoc &DL, ArrayRef<int> Mask, MVT VT,
25952642
if ((Result = lowerVECTOR_SHUFFLEAsShift(DL, Mask, VT, V1, V2, DAG, Subtarget,
25962643
Zeroable)))
25972644
return Result;
2645+
if ((Result =
2646+
lowerVECTOR_SHUFFLE_XVINSVE0(DL, Mask, VT, V1, V2, DAG, Subtarget)))
2647+
return Result;
25982648
if ((Result = lowerVECTOR_SHUFFLEAsByteRotate(DL, Mask, VT, V1, V2, DAG,
25992649
Subtarget)))
26002650
return Result;
@@ -7453,6 +7503,7 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
74537503
NODE_NAME_CASE(XVPERM)
74547504
NODE_NAME_CASE(XVREPLVE0)
74557505
NODE_NAME_CASE(XVREPLVE0Q)
7506+
NODE_NAME_CASE(XVINSVE0)
74567507
NODE_NAME_CASE(VPICK_SEXT_ELT)
74577508
NODE_NAME_CASE(VPICK_ZEXT_ELT)
74587509
NODE_NAME_CASE(VREPLVE)

llvm/lib/Target/LoongArch/LoongArchISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ enum NodeType : unsigned {
151151
XVPERM,
152152
XVREPLVE0,
153153
XVREPLVE0Q,
154+
XVINSVE0,
154155

155156
// Extended vector element extraction
156157
VPICK_SEXT_ELT,

llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def loongarch_xvpermi: SDNode<"LoongArchISD::XVPERMI", SDT_LoongArchV1RUimm>;
2020
def loongarch_xvperm: SDNode<"LoongArchISD::XVPERM", SDT_LoongArchXVPERM>;
2121
def loongarch_xvreplve0: SDNode<"LoongArchISD::XVREPLVE0", SDT_LoongArchXVREPLVE0>;
2222
def loongarch_xvreplve0q: SDNode<"LoongArchISD::XVREPLVE0Q", SDT_LoongArchXVREPLVE0>;
23+
def loongarch_xvinsve0 : SDNode<"LoongArchISD::XVINSVE0", SDT_LoongArchV2RUimm>;
2324
def loongarch_xvmskltz: SDNode<"LoongArchISD::XVMSKLTZ", SDT_LoongArchVMSKCOND>;
2425
def loongarch_xvmskgez: SDNode<"LoongArchISD::XVMSKGEZ", SDT_LoongArchVMSKCOND>;
2526
def loongarch_xvmskeqz: SDNode<"LoongArchISD::XVMSKEQZ", SDT_LoongArchVMSKCOND>;
@@ -1708,6 +1709,14 @@ def : Pat<(vector_insert v4f64:$xd, (f64(bitconvert i64:$rj)), uimm2:$imm),
17081709
(XVINSGR2VR_D v4f64:$xd, GPR:$rj, uimm2:$imm)>;
17091710

17101711
// XVINSVE0_{W/D}
1712+
def : Pat<(loongarch_xvinsve0 v8i32:$xd, v8i32:$xj, uimm3:$imm),
1713+
(XVINSVE0_W v8i32:$xd, v8i32:$xj, uimm3:$imm)>;
1714+
def : Pat<(loongarch_xvinsve0 v4i64:$xd, v4i64:$xj, uimm2:$imm),
1715+
(XVINSVE0_D v4i64:$xd, v4i64:$xj, uimm2:$imm)>;
1716+
def : Pat<(loongarch_xvinsve0 v8f32:$xd, v8f32:$xj, uimm3:$imm),
1717+
(XVINSVE0_W v8f32:$xd, v8f32:$xj, uimm3:$imm)>;
1718+
def : Pat<(loongarch_xvinsve0 v4f64:$xd, v4f64:$xj, uimm2:$imm),
1719+
(XVINSVE0_D v4f64:$xd, v4f64:$xj, uimm2:$imm)>;
17111720
def : Pat<(vector_insert v8f32:$xd, FPR32:$fj, uimm3:$imm),
17121721
(XVINSVE0_W v8f32:$xd, (SUBREG_TO_REG(i64 0), FPR32:$fj, sub_32),
17131722
uimm3:$imm)>;

0 commit comments

Comments
 (0)