[LoongArch] Lower vector select mask generation to [X]VMSK{LT,GE,NE}Z if possible #142109
@llvm/pr-subscribers-backend-loongarch

Author: hev (heiher)

Changes

This patch adds a DAG combine rule for BITCAST nodes converting from vector `i1` masks generated by `setcc` into integer vector types. It recognizes common select mask patterns and lowers them into efficient LoongArch LSX/LASX mask instructions such as:

- [X]VMSKLTZ.{B,H,W,D}
- [X]VMSKGEZ.B
- [X]VMSKNEZ.B

When the vector comparison matches specific patterns (e.g., x < 0, x >= 0, x != 0, etc.), the transformation is performed pre-legalization. This avoids scalarization and unnecessary operations, improving both performance and code size.

Patch is 27.22 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/142109.diff

6 Files Affected:
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index 7a9ec9f5e96b3..10e33734138bf 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -388,8 +388,10 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
// Set DAG combine for 'LSX' feature.
- if (Subtarget.hasExtLSX())
+ if (Subtarget.hasExtLSX()) {
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
+ setTargetDAGCombine(ISD::BITCAST);
+ }
// Compute derived properties from the register classes.
computeRegisterProperties(Subtarget.getRegisterInfo());
@@ -4286,6 +4288,94 @@ static SDValue performSRLCombine(SDNode *N, SelectionDAG &DAG,
return SDValue();
}
+static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const LoongArchSubtarget &Subtarget) {
+ SDLoc DL(N);
+ EVT VT = N->getValueType(0);
+ SDValue Src = N->getOperand(0);
+ EVT SrcVT = Src.getValueType();
+
+ if (!DCI.isBeforeLegalizeOps())
+ return SDValue();
+
+ if (!SrcVT.isSimple() || SrcVT.getScalarType() != MVT::i1)
+ return SDValue();
+
+ if (Src.getOpcode() != ISD::SETCC || !Src.hasOneUse())
+ return SDValue();
+
+ bool UseLASX;
+ EVT CmpVT = Src.getOperand(0).getValueType();
+ EVT EltVT = CmpVT.getVectorElementType();
+ if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() <= 128)
+ UseLASX = false;
+ else if (Subtarget.has32S() && Subtarget.hasExtLASX() &&
+ CmpVT.getSizeInBits() <= 256)
+ UseLASX = true;
+ else
+ return SDValue();
+
+ unsigned ISD = ISD::DELETED_NODE;
+ SDValue SrcN1 = Src.getOperand(1);
+ switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) {
+ default:
+ return SDValue();
+ case ISD::SETEQ:
+ if (EltVT == MVT::i8) {
+ // x == 0 => not (vmsknez.b x)
+ if (ISD::isBuildVectorAllZeros(SrcN1.getNode()))
+ ISD = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
+ // x == -1 => vmsknez.b x
+ else if (ISD::isBuildVectorAllOnes(SrcN1.getNode()))
+ ISD = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
+ }
+ break;
+ case ISD::SETGT:
+ // x > -1 => vmskgez.b x
+ if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8)
+ ISD = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
+ break;
+ case ISD::SETGE:
+ // x >= 0 => vmskgez.b x
+ if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+ ISD = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
+ break;
+ case ISD::SETLT:
+ // x < 0 => vmskltz.{b,h,w,d} x
+ if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) &&
+ (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
+ EltVT == MVT::i64))
+ ISD = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
+ break;
+ case ISD::SETLE:
+ // x <= -1 => vmskltz.{b,h,w,d} x
+ if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) &&
+ (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
+ EltVT == MVT::i64))
+ ISD = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
+ break;
+ case ISD::SETNE:
+ if (EltVT == MVT::i8) {
+ // x != 0 => vmsknez.b x
+ if (ISD::isBuildVectorAllZeros(SrcN1.getNode()))
+ ISD = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
+ // x != -1 => not (vmsknez.b x)
+ else if (ISD::isBuildVectorAllOnes(SrcN1.getNode()))
+ ISD = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
+ }
+ break;
+ }
+
+ if (ISD == ISD::DELETED_NODE)
+ return SDValue();
+
+ SDValue V = DAG.getNode(ISD, DL, MVT::i64, Src.getOperand(0));
+ EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());
+ V = DAG.getZExtOrTrunc(V, DL, T);
+ return DAG.getBitcast(VT, V);
+}
+
static SDValue performORCombine(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const LoongArchSubtarget &Subtarget) {
@@ -5303,6 +5393,8 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
return performSETCCCombine(N, DAG, DCI, Subtarget);
case ISD::SRL:
return performSRLCombine(N, DAG, DCI, Subtarget);
+ case ISD::BITCAST:
+ return performBITCASTCombine(N, DAG, DCI, Subtarget);
case LoongArchISD::BITREV_W:
return performBITREV_WCombine(N, DAG, DCI, Subtarget);
case ISD::INTRINSIC_WO_CHAIN:
@@ -5589,6 +5681,120 @@ static MachineBasicBlock *emitPseudoCTPOP(MachineInstr &MI,
return BB;
}
+static MachineBasicBlock *
+emitPseudoVMSKCOND(MachineInstr &MI, MachineBasicBlock *BB,
+ const LoongArchSubtarget &Subtarget) {
+ const TargetInstrInfo *TII = Subtarget.getInstrInfo();
+ const TargetRegisterClass *RC = &LoongArch::LSX128RegClass;
+ const LoongArchRegisterInfo *TRI = Subtarget.getRegisterInfo();
+ MachineRegisterInfo &MRI = BB->getParent()->getRegInfo();
+ Register Dst = MI.getOperand(0).getReg();
+ Register Src = MI.getOperand(1).getReg();
+ DebugLoc DL = MI.getDebugLoc();
+ unsigned EleBits = 8;
+ unsigned NotOpc = 0;
+ unsigned MskOpc;
+
+ switch (MI.getOpcode()) {
+ default:
+ llvm_unreachable("Unexpected opcode");
+ case LoongArch::PseudoVMSKLTZ_B:
+ MskOpc = LoongArch::VMSKLTZ_B;
+ break;
+ case LoongArch::PseudoVMSKLTZ_H:
+ MskOpc = LoongArch::VMSKLTZ_H;
+ EleBits = 16;
+ break;
+ case LoongArch::PseudoVMSKLTZ_W:
+ MskOpc = LoongArch::VMSKLTZ_W;
+ EleBits = 32;
+ break;
+ case LoongArch::PseudoVMSKLTZ_D:
+ MskOpc = LoongArch::VMSKLTZ_D;
+ EleBits = 64;
+ break;
+ case LoongArch::PseudoVMSKGEZ_B:
+ MskOpc = LoongArch::VMSKGEZ_B;
+ break;
+ case LoongArch::PseudoVMSKEQZ_B:
+ MskOpc = LoongArch::VMSKNZ_B;
+ NotOpc = LoongArch::VNOR_V;
+ break;
+ case LoongArch::PseudoVMSKNEZ_B:
+ MskOpc = LoongArch::VMSKNZ_B;
+ break;
+ case LoongArch::PseudoXVMSKLTZ_B:
+ MskOpc = LoongArch::XVMSKLTZ_B;
+ RC = &LoongArch::LASX256RegClass;
+ break;
+ case LoongArch::PseudoXVMSKLTZ_H:
+ MskOpc = LoongArch::XVMSKLTZ_H;
+ RC = &LoongArch::LASX256RegClass;
+ EleBits = 16;
+ break;
+ case LoongArch::PseudoXVMSKLTZ_W:
+ MskOpc = LoongArch::XVMSKLTZ_W;
+ RC = &LoongArch::LASX256RegClass;
+ EleBits = 32;
+ break;
+ case LoongArch::PseudoXVMSKLTZ_D:
+ MskOpc = LoongArch::XVMSKLTZ_D;
+ RC = &LoongArch::LASX256RegClass;
+ EleBits = 64;
+ break;
+ case LoongArch::PseudoXVMSKGEZ_B:
+ MskOpc = LoongArch::XVMSKGEZ_B;
+ RC = &LoongArch::LASX256RegClass;
+ break;
+ case LoongArch::PseudoXVMSKEQZ_B:
+ MskOpc = LoongArch::XVMSKNZ_B;
+ NotOpc = LoongArch::XVNOR_V;
+ RC = &LoongArch::LASX256RegClass;
+ break;
+ case LoongArch::PseudoXVMSKNEZ_B:
+ MskOpc = LoongArch::XVMSKNZ_B;
+ RC = &LoongArch::LASX256RegClass;
+ break;
+ }
+
+ Register Msk = MRI.createVirtualRegister(RC);
+ if (NotOpc) {
+ Register Tmp = MRI.createVirtualRegister(RC);
+ BuildMI(*BB, MI, DL, TII->get(MskOpc), Tmp).addReg(Src);
+ BuildMI(*BB, MI, DL, TII->get(NotOpc), Msk)
+ .addReg(Tmp, RegState::Kill)
+ .addReg(Tmp, RegState::Kill);
+ } else {
+ BuildMI(*BB, MI, DL, TII->get(MskOpc), Msk).addReg(Src);
+ }
+
+ if (TRI->getRegSizeInBits(*RC) > 128) {
+ Register Lo = MRI.createVirtualRegister(&LoongArch::GPRRegClass);
+ Register Hi = MRI.createVirtualRegister(&LoongArch::GPRRegClass);
+ BuildMI(*BB, MI, DL, TII->get(LoongArch::XVPICKVE2GR_WU), Lo)
+ .addReg(Msk, RegState::Kill)
+ .addImm(0);
+ BuildMI(*BB, MI, DL, TII->get(LoongArch::XVPICKVE2GR_WU), Hi)
+ .addReg(Msk, RegState::Kill)
+ .addImm(4);
+ BuildMI(*BB, MI, DL,
+ TII->get(Subtarget.is64Bit() ? LoongArch::BSTRINS_D
+ : LoongArch::BSTRINS_W),
+ Dst)
+ .addReg(Lo, RegState::Kill)
+ .addReg(Hi, RegState::Kill)
+ .addImm(256 / EleBits - 1)
+ .addImm(128 / EleBits);
+ } else {
+ BuildMI(*BB, MI, DL, TII->get(LoongArch::VPICKVE2GR_HU), Dst)
+ .addReg(Msk, RegState::Kill)
+ .addImm(0);
+ }
+
+ MI.eraseFromParent();
+ return BB;
+}
+
static bool isSelectPseudo(MachineInstr &MI) {
switch (MI.getOpcode()) {
default:
@@ -5795,6 +6001,21 @@ MachineBasicBlock *LoongArchTargetLowering::EmitInstrWithCustomInserter(
return emitPseudoXVINSGR2VR(MI, BB, Subtarget);
case LoongArch::PseudoCTPOP:
return emitPseudoCTPOP(MI, BB, Subtarget);
+ case LoongArch::PseudoVMSKLTZ_B:
+ case LoongArch::PseudoVMSKLTZ_H:
+ case LoongArch::PseudoVMSKLTZ_W:
+ case LoongArch::PseudoVMSKLTZ_D:
+ case LoongArch::PseudoVMSKGEZ_B:
+ case LoongArch::PseudoVMSKEQZ_B:
+ case LoongArch::PseudoVMSKNEZ_B:
+ case LoongArch::PseudoXVMSKLTZ_B:
+ case LoongArch::PseudoXVMSKLTZ_H:
+ case LoongArch::PseudoXVMSKLTZ_W:
+ case LoongArch::PseudoXVMSKLTZ_D:
+ case LoongArch::PseudoXVMSKGEZ_B:
+ case LoongArch::PseudoXVMSKEQZ_B:
+ case LoongArch::PseudoXVMSKNEZ_B:
+ return emitPseudoVMSKCOND(MI, BB, Subtarget);
case TargetOpcode::STATEPOINT:
// STATEPOINT is a pseudo instruction which has no implicit defs/uses
// while bl call instruction (where statepoint will be lowered at the
@@ -5916,6 +6137,14 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VBSLL)
NODE_NAME_CASE(VBSRL)
NODE_NAME_CASE(VLDREPL)
+ NODE_NAME_CASE(VMSKLTZ)
+ NODE_NAME_CASE(VMSKGEZ)
+ NODE_NAME_CASE(VMSKEQZ)
+ NODE_NAME_CASE(VMSKNEZ)
+ NODE_NAME_CASE(XVMSKLTZ)
+ NODE_NAME_CASE(XVMSKGEZ)
+ NODE_NAME_CASE(XVMSKEQZ)
+ NODE_NAME_CASE(XVMSKNEZ)
}
#undef NODE_NAME_CASE
return nullptr;
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
index 6bf295984dfc5..d22b635c5b187 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
@@ -161,7 +161,17 @@ enum NodeType : unsigned {
VBSRL,
// Scalar load broadcast to vector
- VLDREPL
+ VLDREPL,
+
+ // Vector mask set by condition
+ VMSKLTZ,
+ VMSKGEZ,
+ VMSKEQZ,
+ VMSKNEZ,
+ XVMSKLTZ,
+ XVMSKGEZ,
+ XVMSKEQZ,
+ XVMSKNEZ,
// Intrinsic operations end =============================================
};
diff --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
index 59dec76ef1c2d..ff7b0f2ae3f25 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
@@ -12,6 +12,10 @@
// Target nodes.
def loongarch_xvpermi: SDNode<"LoongArchISD::XVPERMI", SDT_LoongArchV1RUimm>;
+def loongarch_xvmskltz: SDNode<"LoongArchISD::XVMSKLTZ", SDT_LoongArchVMSKCOND>;
+def loongarch_xvmskgez: SDNode<"LoongArchISD::XVMSKGEZ", SDT_LoongArchVMSKCOND>;
+def loongarch_xvmskeqz: SDNode<"LoongArchISD::XVMSKEQZ", SDT_LoongArchVMSKCOND>;
+def loongarch_xvmsknez: SDNode<"LoongArchISD::XVMSKNEZ", SDT_LoongArchVMSKCOND>;
def lasxsplati8
: PatFrag<(ops node:$e0),
@@ -1086,6 +1090,16 @@ def PseudoXVINSGR2VR_H
: Pseudo<(outs LASX256:$dst), (ins LASX256:$xd, GPR:$rj, uimm4:$imm)>;
} // usesCustomInserter = 1, Constraints = "$xd = $dst"
+let usesCustomInserter = 1, hasSideEffects = 0, mayLoad = 0, mayStore = 0 in {
+def PseudoXVMSKLTZ_B : Pseudo<(outs GPR:$rd), (ins LASX256:$vj)>;
+def PseudoXVMSKLTZ_H : Pseudo<(outs GPR:$rd), (ins LASX256:$vj)>;
+def PseudoXVMSKLTZ_W : Pseudo<(outs GPR:$rd), (ins LASX256:$vj)>;
+def PseudoXVMSKLTZ_D : Pseudo<(outs GPR:$rd), (ins LASX256:$vj)>;
+def PseudoXVMSKGEZ_B : Pseudo<(outs GPR:$rd), (ins LASX256:$vj)>;
+def PseudoXVMSKEQZ_B : Pseudo<(outs GPR:$rd), (ins LASX256:$vj)>;
+def PseudoXVMSKNEZ_B : Pseudo<(outs GPR:$rd), (ins LASX256:$vj)>;
+} // usesCustomInserter = 1, hasSideEffects = 0, mayLoad = 0, mayStore = 0
+
} // Predicates = [HasExtLASX]
multiclass PatXr<SDPatternOperator OpNode, string Inst> {
@@ -1856,6 +1870,15 @@ def : Pat<(vt (concat_vectors LSX128:$vd, LSX128:$vj)),
defm : PatXrXr<abds, "XVABSD">;
defm : PatXrXrU<abdu, "XVABSD">;
+// Vector mask set by condition
+def : Pat<(loongarch_xvmskltz (v32i8 LASX256:$vj)), (PseudoXVMSKLTZ_B LASX256:$vj)>;
+def : Pat<(loongarch_xvmskltz (v16i16 LASX256:$vj)), (PseudoXVMSKLTZ_H LASX256:$vj)>;
+def : Pat<(loongarch_xvmskltz (v8i32 LASX256:$vj)), (PseudoXVMSKLTZ_W LASX256:$vj)>;
+def : Pat<(loongarch_xvmskltz (v4i64 LASX256:$vj)), (PseudoXVMSKLTZ_D LASX256:$vj)>;
+def : Pat<(loongarch_xvmskgez (v32i8 LASX256:$vj)), (PseudoXVMSKGEZ_B LASX256:$vj)>;
+def : Pat<(loongarch_xvmskeqz (v32i8 LASX256:$vj)), (PseudoXVMSKEQZ_B LASX256:$vj)>;
+def : Pat<(loongarch_xvmsknez (v32i8 LASX256:$vj)), (PseudoXVMSKNEZ_B LASX256:$vj)>;
+
} // Predicates = [HasExtLASX]
/// Intrinsic pattern
diff --git a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
index b2b3b65155265..d73d78083ddcd 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
@@ -31,6 +31,7 @@ def SDT_LoongArchVreplgr2vr : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisVec<0>, S
def SDT_LoongArchVFRECIPE : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVec<0>, SDTCisSameAs<0, 1>]>;
def SDT_LoongArchVFRSQRTE : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVec<0>, SDTCisSameAs<0, 1>]>;
def SDT_LoongArchVLDREPL : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisPtrTy<1>]>;
+def SDT_LoongArchVMSKCOND : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisVec<1>]>;
// Target nodes.
def loongarch_vreplve : SDNode<"LoongArchISD::VREPLVE", SDT_LoongArchVreplve>;
@@ -74,6 +75,11 @@ def loongarch_vldrepl
: SDNode<"LoongArchISD::VLDREPL",
SDT_LoongArchVLDREPL, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
+def loongarch_vmskltz: SDNode<"LoongArchISD::VMSKLTZ", SDT_LoongArchVMSKCOND>;
+def loongarch_vmskgez: SDNode<"LoongArchISD::VMSKGEZ", SDT_LoongArchVMSKCOND>;
+def loongarch_vmskeqz: SDNode<"LoongArchISD::VMSKEQZ", SDT_LoongArchVMSKCOND>;
+def loongarch_vmsknez: SDNode<"LoongArchISD::VMSKNEZ", SDT_LoongArchVMSKCOND>;
+
def immZExt1 : ImmLeaf<i64, [{return isUInt<1>(Imm);}]>;
def immZExt2 : ImmLeaf<i64, [{return isUInt<2>(Imm);}]>;
def immZExt3 : ImmLeaf<i64, [{return isUInt<3>(Imm);}]>;
@@ -1266,6 +1272,16 @@ let usesCustomInserter = 1 in
def PseudoCTPOP : Pseudo<(outs GPR:$rd), (ins GPR:$rj),
[(set GPR:$rd, (ctpop GPR:$rj))]>;
+let usesCustomInserter = 1, hasSideEffects = 0, mayLoad = 0, mayStore = 0 in {
+def PseudoVMSKLTZ_B : Pseudo<(outs GPR:$rd), (ins LSX128:$vj)>;
+def PseudoVMSKLTZ_H : Pseudo<(outs GPR:$rd), (ins LSX128:$vj)>;
+def PseudoVMSKLTZ_W : Pseudo<(outs GPR:$rd), (ins LSX128:$vj)>;
+def PseudoVMSKLTZ_D : Pseudo<(outs GPR:$rd), (ins LSX128:$vj)>;
+def PseudoVMSKGEZ_B : Pseudo<(outs GPR:$rd), (ins LSX128:$vj)>;
+def PseudoVMSKEQZ_B : Pseudo<(outs GPR:$rd), (ins LSX128:$vj)>;
+def PseudoVMSKNEZ_B : Pseudo<(outs GPR:$rd), (ins LSX128:$vj)>;
+} // usesCustomInserter = 1, hasSideEffects = 0, mayLoad = 0, mayStore = 0
+
} // Predicates = [HasExtLSX]
multiclass PatVr<SDPatternOperator OpNode, string Inst> {
@@ -2050,6 +2066,15 @@ def : Pat<(f64 f64imm_vldi:$in),
defm : PatVrVr<abds, "VABSD">;
defm : PatVrVrU<abdu, "VABSD">;
+// Vector mask set by condition
+def : Pat<(loongarch_vmskltz (v16i8 LSX128:$vj)), (PseudoVMSKLTZ_B LSX128:$vj)>;
+def : Pat<(loongarch_vmskltz (v8i16 LSX128:$vj)), (PseudoVMSKLTZ_H LSX128:$vj)>;
+def : Pat<(loongarch_vmskltz (v4i32 LSX128:$vj)), (PseudoVMSKLTZ_W LSX128:$vj)>;
+def : Pat<(loongarch_vmskltz (v2i64 LSX128:$vj)), (PseudoVMSKLTZ_D LSX128:$vj)>;
+def : Pat<(loongarch_vmskgez (v16i8 LSX128:$vj)), (PseudoVMSKGEZ_B LSX128:$vj)>;
+def : Pat<(loongarch_vmskeqz (v16i8 LSX128:$vj)), (PseudoVMSKEQZ_B LSX128:$vj)>;
+def : Pat<(loongarch_vmsknez (v16i8 LSX128:$vj)), (PseudoVMSKNEZ_B LSX128:$vj)>;
+
} // Predicates = [HasExtLSX]
/// Intrinsic pattern
diff --git a/llvm/test/CodeGen/LoongArch/lasx/xvmskcond.ll b/llvm/test/CodeGen/LoongArch/lasx/xvmskcond.ll
new file mode 100644
index 0000000000000..93b081e450985
--- /dev/null
+++ b/llvm/test/CodeGen/LoongArch/lasx/xvmskcond.ll
@@ -0,0 +1,200 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc --mtriple=loongarch64 --mattr=+lasx < %s | FileCheck %s
+
+define i32 @xvmsk_eq_allzeros_i8(<32 x i8 > %a) {
+; CHECK-LABEL: xvmsk_eq_allzeros_i8:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: xvmsknz.b $xr0, $xr0
+; CHECK-NEXT: xvnor.v $xr0, $xr0, $xr0
+; CHECK-NEXT: xvpickve2gr.wu $a0, $xr0, 0
+; CHECK-NEXT: xvpickve2gr.wu $a1, $xr0, 4
+; CHECK-NEXT: bstrins.d $a0, $a1, 31, 16
+; CHECK-NEXT: ret
+entry:
+ %1 = icmp eq <32 x i8> %a, splat (i8 0)
+ %2 = bitcast <32 x i1> %1 to i32
+ ret i32 %2
+}
+
+define i32 @xvmsk_eq_allones_i8(<32 x i8 > %a) {
+; CHECK-LABEL: xvmsk_eq_allones_i8:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: xvmsknz.b $xr0, $xr0
+; CHECK-NEXT: xvpickve2gr.wu $a0, $xr0, 0
+; CHECK-NEXT: xvpickve2gr.wu $a1, $xr0, 4
+; CHECK-NEXT: bstrins.d $a0, $a1, 31, 16
+; CHECK-NEXT: ret
+entry:
+ %1 = icmp eq <32 x i8> %a, splat (i8 -1)
+ %2 = bitcast <32 x i1> %1 to i32
+ ret i32 %2
+}
+
+define i32 @xvmsk_sgt_allones_i8(<32 x i8 > %a) {
+; CHECK-LABEL: xvmsk_sgt_allones_i8:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: xvmskgez.b $xr0, $xr0
+; CHECK-NEXT: xvpickve2gr.wu $a0, $xr0, 0
+; CHECK-NEXT: xvpickve2gr.wu $a1, $xr0, 4
+; CHECK-NEXT: bstrins.d $a0, $a1, 31, 16
+; CHECK-NEXT: ret
+entry:
+ %1 = icmp sgt <32 x i8> %a, splat (i8 -1)
+ %2 = bitcast <32 x i1> %1 to i32
+ ret i32 %2
+}
+
+define i32 @xvmsk_sge_allzeros_i8(<32 x i8 > %a) {
+; CHECK-LABEL: xvmsk_sge_allzeros_i8:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: xvmskgez.b $xr0, $xr0
+; CHECK-NEXT: xvpickve2gr.wu $a0, $xr0, 0
+; CHECK-NEXT: xvpickve2gr.wu $a1, $xr0, 4
+; CHECK-NEXT: bstrins.d $a0, $a1, 31, 16
+; CHECK-NEXT: ret
+entry:
+ %1 = icmp sge <32 x i8> %a, splat (i8 0)
+ %2 = bitcast <32 x i1> %1 to i32
+ ret i32 %2
+}
+
+define i32 @xvmsk_slt_allzeros_i8(<32 x i8 > %a) {
+; CHECK-LABEL: xvmsk_slt_allzeros_i8:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: xvmskltz.b $xr0, $xr0
+; CHECK-NEXT: xvpickve2gr.wu $a0, $xr0, 0
+; CHECK-NEXT: xvpickve2gr.wu $a1, $xr0, 4
+; CHECK-NEXT: bstrins.d $a0, $a1, 31, 16
+; CHECK-NEXT: ret
+entry:
+ %1 = icmp slt <32 x i8> %a, splat (i8 0)
+ %2 = bitcast <32 x i1> %1 to i32
+ ret i32 %2
+}
+
+define i16 @xvmsk_slt_allzeros_i16(<16 x i16 > %a) {
+; CHECK-LABEL: xvmsk_slt_allzeros_i16:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: xvmskltz.h $xr0, $xr0
+; CHECK-NEXT: xvpickve2gr.wu $a0, $xr0, 0
+; CHECK-NEXT: xvpickve2gr.wu $a1, $xr0, 4
+; CHECK-NEXT: bstrins.d $a0, $a1, 15, 8
+; CHECK-NEXT: ret
+entry:
+ %1 = icmp slt <16 x i16> %a, splat (i16 0)
+ %2 = bitcast <16 x i1> %1 to i16
+ ret i16 %2
+}
+
+define i8 @xvmsk_slt_allzeros_i32(<8 x i32 > %a) {
+; CHECK-LABEL: xvmsk_slt_allzeros_i32:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: xvmskltz.w $xr0, $xr0
+; CHECK-NEXT: xvpickve2gr.wu $a0, $xr0, 0
+; CHECK-NEXT: xvpickve2gr.wu $a1, $xr0, 4
+; CHECK-NEXT: bstrins.d $a0, $a1, 7, 4
+; CHECK-NEXT: ret
+entry:
+ %1 = icmp slt <8 x i32> %a, splat (i32 0)
+ %2 = bitcast <8 x i1> %1 to i8
+ ret i8 %2
+}
+
+define i4 @xvmsk_slt_allzeros_i64(<4 x i64 > %a) {
+; CHECK-LABEL: xvmsk_slt_allzeros_i64:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: xvmskltz.d $xr0, $xr0
+; CHECK-NEXT: xvpickve2gr.wu $a0, $xr0, 0
+; CHECK-NEXT: xvpickve2gr.wu $a1, $xr0, 4
+; CHECK-NEXT: bstrins.d $a0, $a1, 3, 2
+; CHECK-NEXT: ret
+entry:
+ %1 = icmp slt <4 x i64> %a, splat (i64 0)
+ %2 = bitcast <4 x i1> %1 to i4
+ ret i4 %2
+}
+
+define i32 @xvmsk_sle_allones_i8(<32 x i8 > %a) {
+; CHECK-LABEL: xvmsk_sle_allones_i8:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: xvmskltz.b $xr0, $xr0
+; CHECK-NEXT: xvpickve2gr.wu $a0, $xr0, 0
+; CHECK-NEXT: xvpickve2gr.wu $a1, $xr0, 4
+; CHECK-NEXT: bstrins.d $a0, $a1, 31, 16
+; CHECK-NEXT: ret
+entry:
+ %1 = icmp sle <32 x i8> %a, splat (i8 -1)
+ %2 = bit...
[truncated]
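For illustration, here is a minimal sketch of the 128-bit (LSX) counterpart of the tests above; the function name is made up and this snippet is not taken from the truncated portion of the patch. With this change, the setcc + bitcast pair should be matched in performBITCASTCombine, selected through LoongArchISD::VMSKLTZ, and emitted by the custom inserter as a single vmskltz.b plus a vpickve2gr.hu move instead of being scalarized:

; IR sketch: a signed "x < 0" byte compare whose <16 x i1> mask is bitcast to i16.
define i16 @vmsk_slt_allzeros_i8(<16 x i8> %a) {
entry:
  %cmp = icmp slt <16 x i8> %a, splat (i8 0)
  %mask = bitcast <16 x i1> %cmp to i16
  ret i16 %mask
}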
For byte we have

This patch seems to have no impact on the llvm-test-suite.
We also have not (x <= 0):

vslei.b   t, x, 0
vmskgez.b t, t
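For context, the IR shape that sequence corresponds to is a signed x > 0 compare, sketched below; the function name is hypothetical and this case is not covered by the patch's tests:

; not (x <= 0) is the same as x > 0 for signed integers; the vslei.b +
; vmskgez.b pair above would cover this setcc + bitcast shape.
define i16 @vmsk_sgt_allzeros_i8(<16 x i8> %a) {
entry:
  %cmp = icmp sgt <16 x i8> %a, splat (i8 0)
  %mask = bitcast <16 x i1> %cmp to i16
  ret i16 %mask
}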
LGTM. @wangleiat What do you think?