[LoongArch] Lower vector select mask generation to [X]VMSK{LT,GE,NE}Z if possible #142109

Merged

heiher merged 7 commits into llvm:main from vmsk-2 on Jun 5, 2025

Conversation

@heiher (Member) commented on May 30, 2025

This patch adds a DAG combine rule for BITCAST nodes that convert vector i1 masks produced by setcc into scalar integer types. It recognizes common select mask patterns and lowers them to efficient LoongArch LSX/LASX mask instructions such as:

  • [X]VMSKLTZ.{B,H,W,D}
  • [X]VMSKGEZ.B
  • [X]VMSKNEZ.B

When the vector comparison matches one of the recognized patterns (e.g., x < 0, x >= 0, or x != 0), the transformation is performed before legalization. This avoids scalarization and unnecessary operations, improving both performance and code size.
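
As a concrete illustration, here is a minimal IR sketch of one pattern the combine recognizes, modeled on the added LASX tests; since the LSX test bodies are truncated below, the function name here is hypothetical:

define i16 @vmsk_slt_allzeros_i8(<16 x i8> %a) {
entry:
  ; x < 0 over 16 bytes: setcc yields a <16 x i1> mask and the bitcast
  ; packs it into an i16. With this patch the pair lowers to a single
  ; vmskltz.b plus one element move instead of being scalarized.
  %cmp = icmp slt <16 x i8> %a, splat (i8 0)
  %mask = bitcast <16 x i1> %cmp to i16
  ret i16 %mask
}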

@llvmbot (Member) commented on May 30, 2025

@llvm/pr-subscribers-backend-loongarch

Author: hev (heiher)

Changes


Patch is 27.22 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/142109.diff

6 Files Affected:

  • (modified) llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp (+230-1)
  • (modified) llvm/lib/Target/LoongArch/LoongArchISelLowering.h (+11-1)
  • (modified) llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td (+23)
  • (modified) llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td (+25)
  • (added) llvm/test/CodeGen/LoongArch/lasx/xvmskcond.ll (+200)
  • (added) llvm/test/CodeGen/LoongArch/lsx/vmskcond.ll (+172)
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index 7a9ec9f5e96b3..10e33734138bf 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -388,8 +388,10 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
 
   // Set DAG combine for 'LSX' feature.
 
-  if (Subtarget.hasExtLSX())
+  if (Subtarget.hasExtLSX()) {
     setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
+    setTargetDAGCombine(ISD::BITCAST);
+  }
 
   // Compute derived properties from the register classes.
   computeRegisterProperties(Subtarget.getRegisterInfo());
@@ -4286,6 +4288,94 @@ static SDValue performSRLCombine(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
+static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
+                                     TargetLowering::DAGCombinerInfo &DCI,
+                                     const LoongArchSubtarget &Subtarget) {
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  SDValue Src = N->getOperand(0);
+  EVT SrcVT = Src.getValueType();
+
+  if (!DCI.isBeforeLegalizeOps())
+    return SDValue();
+
+  if (!SrcVT.isSimple() || SrcVT.getScalarType() != MVT::i1)
+    return SDValue();
+
+  if (Src.getOpcode() != ISD::SETCC || !Src.hasOneUse())
+    return SDValue();
+
+  bool UseLASX;
+  EVT CmpVT = Src.getOperand(0).getValueType();
+  EVT EltVT = CmpVT.getVectorElementType();
+  if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() <= 128)
+    UseLASX = false;
+  else if (Subtarget.has32S() && Subtarget.hasExtLASX() &&
+           CmpVT.getSizeInBits() <= 256)
+    UseLASX = true;
+  else
+    return SDValue();
+
+  unsigned ISD = ISD::DELETED_NODE;
+  SDValue SrcN1 = Src.getOperand(1);
+  switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) {
+  default:
+    return SDValue();
+  case ISD::SETEQ:
+    if (EltVT == MVT::i8) {
+      // x == 0 => not (vmsknez.b x)
+      if (ISD::isBuildVectorAllZeros(SrcN1.getNode()))
+        ISD = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
+      // x == -1 => vmsknez.b x
+      else if (ISD::isBuildVectorAllOnes(SrcN1.getNode()))
+        ISD = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
+    }
+    break;
+  case ISD::SETGT:
+    // x > -1 => vmskgez.b x
+    if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8)
+      ISD = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
+    break;
+  case ISD::SETGE:
+    // x >= 0 => vmskgez.b x
+    if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+      ISD = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
+    break;
+  case ISD::SETLT:
+    // x < 0 => vmskltz.{b,h,w,d} x
+    if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) &&
+        (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
+         EltVT == MVT::i64))
+      ISD = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
+    break;
+  case ISD::SETLE:
+    // x <= -1 => vmskltz.{b,h,w,d} x
+    if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) &&
+        (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
+         EltVT == MVT::i64))
+      ISD = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
+    break;
+  case ISD::SETNE:
+    if (EltVT == MVT::i8) {
+      // x != 0 => vmsknez.b x
+      if (ISD::isBuildVectorAllZeros(SrcN1.getNode()))
+        ISD = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
+      // x != -1 => not (vmsknez.b x)
+      else if (ISD::isBuildVectorAllOnes(SrcN1.getNode()))
+        ISD = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
+    }
+    break;
+  }
+
+  if (ISD == ISD::DELETED_NODE)
+    return SDValue();
+
+  SDValue V = DAG.getNode(ISD, DL, MVT::i64, Src.getOperand(0));
+  EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());
+  V = DAG.getZExtOrTrunc(V, DL, T);
+  return DAG.getBitcast(VT, V);
+}
+
 static SDValue performORCombine(SDNode *N, SelectionDAG &DAG,
                                 TargetLowering::DAGCombinerInfo &DCI,
                                 const LoongArchSubtarget &Subtarget) {
@@ -5303,6 +5393,8 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
     return performSETCCCombine(N, DAG, DCI, Subtarget);
   case ISD::SRL:
     return performSRLCombine(N, DAG, DCI, Subtarget);
+  case ISD::BITCAST:
+    return performBITCASTCombine(N, DAG, DCI, Subtarget);
   case LoongArchISD::BITREV_W:
     return performBITREV_WCombine(N, DAG, DCI, Subtarget);
   case ISD::INTRINSIC_WO_CHAIN:
@@ -5589,6 +5681,120 @@ static MachineBasicBlock *emitPseudoCTPOP(MachineInstr &MI,
   return BB;
 }
 
+static MachineBasicBlock *
+emitPseudoVMSKCOND(MachineInstr &MI, MachineBasicBlock *BB,
+                   const LoongArchSubtarget &Subtarget) {
+  const TargetInstrInfo *TII = Subtarget.getInstrInfo();
+  const TargetRegisterClass *RC = &LoongArch::LSX128RegClass;
+  const LoongArchRegisterInfo *TRI = Subtarget.getRegisterInfo();
+  MachineRegisterInfo &MRI = BB->getParent()->getRegInfo();
+  Register Dst = MI.getOperand(0).getReg();
+  Register Src = MI.getOperand(1).getReg();
+  DebugLoc DL = MI.getDebugLoc();
+  unsigned EleBits = 8;
+  unsigned NotOpc = 0;
+  unsigned MskOpc;
+
+  switch (MI.getOpcode()) {
+  default:
+    llvm_unreachable("Unexpected opcode");
+  case LoongArch::PseudoVMSKLTZ_B:
+    MskOpc = LoongArch::VMSKLTZ_B;
+    break;
+  case LoongArch::PseudoVMSKLTZ_H:
+    MskOpc = LoongArch::VMSKLTZ_H;
+    EleBits = 16;
+    break;
+  case LoongArch::PseudoVMSKLTZ_W:
+    MskOpc = LoongArch::VMSKLTZ_W;
+    EleBits = 32;
+    break;
+  case LoongArch::PseudoVMSKLTZ_D:
+    MskOpc = LoongArch::VMSKLTZ_D;
+    EleBits = 64;
+    break;
+  case LoongArch::PseudoVMSKGEZ_B:
+    MskOpc = LoongArch::VMSKGEZ_B;
+    break;
+  case LoongArch::PseudoVMSKEQZ_B:
+    MskOpc = LoongArch::VMSKNZ_B;
+    NotOpc = LoongArch::VNOR_V;
+    break;
+  case LoongArch::PseudoVMSKNEZ_B:
+    MskOpc = LoongArch::VMSKNZ_B;
+    break;
+  case LoongArch::PseudoXVMSKLTZ_B:
+    MskOpc = LoongArch::XVMSKLTZ_B;
+    RC = &LoongArch::LASX256RegClass;
+    break;
+  case LoongArch::PseudoXVMSKLTZ_H:
+    MskOpc = LoongArch::XVMSKLTZ_H;
+    RC = &LoongArch::LASX256RegClass;
+    EleBits = 16;
+    break;
+  case LoongArch::PseudoXVMSKLTZ_W:
+    MskOpc = LoongArch::XVMSKLTZ_W;
+    RC = &LoongArch::LASX256RegClass;
+    EleBits = 32;
+    break;
+  case LoongArch::PseudoXVMSKLTZ_D:
+    MskOpc = LoongArch::XVMSKLTZ_D;
+    RC = &LoongArch::LASX256RegClass;
+    EleBits = 64;
+    break;
+  case LoongArch::PseudoXVMSKGEZ_B:
+    MskOpc = LoongArch::XVMSKGEZ_B;
+    RC = &LoongArch::LASX256RegClass;
+    break;
+  case LoongArch::PseudoXVMSKEQZ_B:
+    MskOpc = LoongArch::XVMSKNZ_B;
+    NotOpc = LoongArch::XVNOR_V;
+    RC = &LoongArch::LASX256RegClass;
+    break;
+  case LoongArch::PseudoXVMSKNEZ_B:
+    MskOpc = LoongArch::XVMSKNZ_B;
+    RC = &LoongArch::LASX256RegClass;
+    break;
+  }
+
+  Register Msk = MRI.createVirtualRegister(RC);
+  if (NotOpc) {
+    Register Tmp = MRI.createVirtualRegister(RC);
+    BuildMI(*BB, MI, DL, TII->get(MskOpc), Tmp).addReg(Src);
+    BuildMI(*BB, MI, DL, TII->get(NotOpc), Msk)
+        .addReg(Tmp, RegState::Kill)
+        .addReg(Tmp, RegState::Kill);
+  } else {
+    BuildMI(*BB, MI, DL, TII->get(MskOpc), Msk).addReg(Src);
+  }
+
+  if (TRI->getRegSizeInBits(*RC) > 128) {
+    Register Lo = MRI.createVirtualRegister(&LoongArch::GPRRegClass);
+    Register Hi = MRI.createVirtualRegister(&LoongArch::GPRRegClass);
+    BuildMI(*BB, MI, DL, TII->get(LoongArch::XVPICKVE2GR_WU), Lo)
+        .addReg(Msk, RegState::Kill)
+        .addImm(0);
+    BuildMI(*BB, MI, DL, TII->get(LoongArch::XVPICKVE2GR_WU), Hi)
+        .addReg(Msk, RegState::Kill)
+        .addImm(4);
+    BuildMI(*BB, MI, DL,
+            TII->get(Subtarget.is64Bit() ? LoongArch::BSTRINS_D
+                                         : LoongArch::BSTRINS_W),
+            Dst)
+        .addReg(Lo, RegState::Kill)
+        .addReg(Hi, RegState::Kill)
+        .addImm(256 / EleBits - 1)
+        .addImm(128 / EleBits);
+  } else {
+    BuildMI(*BB, MI, DL, TII->get(LoongArch::VPICKVE2GR_HU), Dst)
+        .addReg(Msk, RegState::Kill)
+        .addImm(0);
+  }
+
+  MI.eraseFromParent();
+  return BB;
+}
+
 static bool isSelectPseudo(MachineInstr &MI) {
   switch (MI.getOpcode()) {
   default:
@@ -5795,6 +6001,21 @@ MachineBasicBlock *LoongArchTargetLowering::EmitInstrWithCustomInserter(
     return emitPseudoXVINSGR2VR(MI, BB, Subtarget);
   case LoongArch::PseudoCTPOP:
     return emitPseudoCTPOP(MI, BB, Subtarget);
+  case LoongArch::PseudoVMSKLTZ_B:
+  case LoongArch::PseudoVMSKLTZ_H:
+  case LoongArch::PseudoVMSKLTZ_W:
+  case LoongArch::PseudoVMSKLTZ_D:
+  case LoongArch::PseudoVMSKGEZ_B:
+  case LoongArch::PseudoVMSKEQZ_B:
+  case LoongArch::PseudoVMSKNEZ_B:
+  case LoongArch::PseudoXVMSKLTZ_B:
+  case LoongArch::PseudoXVMSKLTZ_H:
+  case LoongArch::PseudoXVMSKLTZ_W:
+  case LoongArch::PseudoXVMSKLTZ_D:
+  case LoongArch::PseudoXVMSKGEZ_B:
+  case LoongArch::PseudoXVMSKEQZ_B:
+  case LoongArch::PseudoXVMSKNEZ_B:
+    return emitPseudoVMSKCOND(MI, BB, Subtarget);
   case TargetOpcode::STATEPOINT:
     // STATEPOINT is a pseudo instruction which has no implicit defs/uses
     // while bl call instruction (where statepoint will be lowered at the
@@ -5916,6 +6137,14 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
     NODE_NAME_CASE(VBSLL)
     NODE_NAME_CASE(VBSRL)
     NODE_NAME_CASE(VLDREPL)
+    NODE_NAME_CASE(VMSKLTZ)
+    NODE_NAME_CASE(VMSKGEZ)
+    NODE_NAME_CASE(VMSKEQZ)
+    NODE_NAME_CASE(VMSKNEZ)
+    NODE_NAME_CASE(XVMSKLTZ)
+    NODE_NAME_CASE(XVMSKGEZ)
+    NODE_NAME_CASE(XVMSKEQZ)
+    NODE_NAME_CASE(XVMSKNEZ)
   }
 #undef NODE_NAME_CASE
   return nullptr;
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
index 6bf295984dfc5..d22b635c5b187 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.h
@@ -161,7 +161,17 @@ enum NodeType : unsigned {
   VBSRL,
 
   // Scalar load broadcast to vector
-  VLDREPL
+  VLDREPL,
+
+  // Vector mask set by condition
+  VMSKLTZ,
+  VMSKGEZ,
+  VMSKEQZ,
+  VMSKNEZ,
+  XVMSKLTZ,
+  XVMSKGEZ,
+  XVMSKEQZ,
+  XVMSKNEZ,
 
   // Intrinsic operations end =============================================
 };
diff --git a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
index 59dec76ef1c2d..ff7b0f2ae3f25 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLASXInstrInfo.td
@@ -12,6 +12,10 @@
 
 // Target nodes.
 def loongarch_xvpermi: SDNode<"LoongArchISD::XVPERMI", SDT_LoongArchV1RUimm>;
+def loongarch_xvmskltz: SDNode<"LoongArchISD::XVMSKLTZ", SDT_LoongArchVMSKCOND>;
+def loongarch_xvmskgez: SDNode<"LoongArchISD::XVMSKGEZ", SDT_LoongArchVMSKCOND>;
+def loongarch_xvmskeqz: SDNode<"LoongArchISD::XVMSKEQZ", SDT_LoongArchVMSKCOND>;
+def loongarch_xvmsknez: SDNode<"LoongArchISD::XVMSKNEZ", SDT_LoongArchVMSKCOND>;
 
 def lasxsplati8
   : PatFrag<(ops node:$e0),
@@ -1086,6 +1090,16 @@ def PseudoXVINSGR2VR_H
   : Pseudo<(outs LASX256:$dst), (ins LASX256:$xd, GPR:$rj, uimm4:$imm)>;
 } //  usesCustomInserter = 1, Constraints = "$xd = $dst"
 
+let usesCustomInserter = 1, hasSideEffects = 0, mayLoad = 0, mayStore = 0 in {
+def PseudoXVMSKLTZ_B : Pseudo<(outs GPR:$rd), (ins LASX256:$vj)>;
+def PseudoXVMSKLTZ_H : Pseudo<(outs GPR:$rd), (ins LASX256:$vj)>;
+def PseudoXVMSKLTZ_W : Pseudo<(outs GPR:$rd), (ins LASX256:$vj)>;
+def PseudoXVMSKLTZ_D : Pseudo<(outs GPR:$rd), (ins LASX256:$vj)>;
+def PseudoXVMSKGEZ_B : Pseudo<(outs GPR:$rd), (ins LASX256:$vj)>;
+def PseudoXVMSKEQZ_B : Pseudo<(outs GPR:$rd), (ins LASX256:$vj)>;
+def PseudoXVMSKNEZ_B : Pseudo<(outs GPR:$rd), (ins LASX256:$vj)>;
+} // usesCustomInserter = 1, hasSideEffects = 0, mayLoad = 0, mayStore = 0
+
 } // Predicates = [HasExtLASX]
 
 multiclass PatXr<SDPatternOperator OpNode, string Inst> {
@@ -1856,6 +1870,15 @@ def : Pat<(vt (concat_vectors LSX128:$vd, LSX128:$vj)),
 defm : PatXrXr<abds, "XVABSD">;
 defm : PatXrXrU<abdu, "XVABSD">;
 
+// Vector mask set by condition
+def : Pat<(loongarch_xvmskltz (v32i8 LASX256:$vj)), (PseudoXVMSKLTZ_B LASX256:$vj)>;
+def : Pat<(loongarch_xvmskltz (v16i16 LASX256:$vj)), (PseudoXVMSKLTZ_H LASX256:$vj)>;
+def : Pat<(loongarch_xvmskltz (v8i32 LASX256:$vj)), (PseudoXVMSKLTZ_W LASX256:$vj)>;
+def : Pat<(loongarch_xvmskltz (v4i64 LASX256:$vj)), (PseudoXVMSKLTZ_D LASX256:$vj)>;
+def : Pat<(loongarch_xvmskgez (v32i8 LASX256:$vj)), (PseudoXVMSKGEZ_B LASX256:$vj)>;
+def : Pat<(loongarch_xvmskeqz (v32i8 LASX256:$vj)), (PseudoXVMSKEQZ_B LASX256:$vj)>;
+def : Pat<(loongarch_xvmsknez (v32i8 LASX256:$vj)), (PseudoXVMSKNEZ_B LASX256:$vj)>;
+
 } // Predicates = [HasExtLASX]
 
 /// Intrinsic pattern
diff --git a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
index b2b3b65155265..d73d78083ddcd 100644
--- a/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
+++ b/llvm/lib/Target/LoongArch/LoongArchLSXInstrInfo.td
@@ -31,6 +31,7 @@ def SDT_LoongArchVreplgr2vr : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisVec<0>, S
 def SDT_LoongArchVFRECIPE : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVec<0>, SDTCisSameAs<0, 1>]>;
 def SDT_LoongArchVFRSQRTE : SDTypeProfile<1, 1, [SDTCisFP<0>, SDTCisVec<0>, SDTCisSameAs<0, 1>]>;
 def SDT_LoongArchVLDREPL : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisPtrTy<1>]>;
+def SDT_LoongArchVMSKCOND : SDTypeProfile<1, 1, [SDTCisInt<0>, SDTCisVec<1>]>;
 
 // Target nodes.
 def loongarch_vreplve : SDNode<"LoongArchISD::VREPLVE", SDT_LoongArchVreplve>;
@@ -74,6 +75,11 @@ def loongarch_vldrepl
     : SDNode<"LoongArchISD::VLDREPL",
              SDT_LoongArchVLDREPL, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
 
+def loongarch_vmskltz: SDNode<"LoongArchISD::VMSKLTZ", SDT_LoongArchVMSKCOND>;
+def loongarch_vmskgez: SDNode<"LoongArchISD::VMSKGEZ", SDT_LoongArchVMSKCOND>;
+def loongarch_vmskeqz: SDNode<"LoongArchISD::VMSKEQZ", SDT_LoongArchVMSKCOND>;
+def loongarch_vmsknez: SDNode<"LoongArchISD::VMSKNEZ", SDT_LoongArchVMSKCOND>;
+
 def immZExt1 : ImmLeaf<i64, [{return isUInt<1>(Imm);}]>;
 def immZExt2 : ImmLeaf<i64, [{return isUInt<2>(Imm);}]>;
 def immZExt3 : ImmLeaf<i64, [{return isUInt<3>(Imm);}]>;
@@ -1266,6 +1272,16 @@ let usesCustomInserter = 1 in
 def PseudoCTPOP : Pseudo<(outs GPR:$rd), (ins GPR:$rj),
                          [(set GPR:$rd, (ctpop GPR:$rj))]>;
 
+let usesCustomInserter = 1, hasSideEffects = 0, mayLoad = 0, mayStore = 0 in {
+def PseudoVMSKLTZ_B : Pseudo<(outs GPR:$rd), (ins LSX128:$vj)>;
+def PseudoVMSKLTZ_H : Pseudo<(outs GPR:$rd), (ins LSX128:$vj)>;
+def PseudoVMSKLTZ_W : Pseudo<(outs GPR:$rd), (ins LSX128:$vj)>;
+def PseudoVMSKLTZ_D : Pseudo<(outs GPR:$rd), (ins LSX128:$vj)>;
+def PseudoVMSKGEZ_B : Pseudo<(outs GPR:$rd), (ins LSX128:$vj)>;
+def PseudoVMSKEQZ_B : Pseudo<(outs GPR:$rd), (ins LSX128:$vj)>;
+def PseudoVMSKNEZ_B : Pseudo<(outs GPR:$rd), (ins LSX128:$vj)>;
+} // usesCustomInserter = 1, hasSideEffects = 0, mayLoad = 0, mayStore = 0
+
 } // Predicates = [HasExtLSX]
 
 multiclass PatVr<SDPatternOperator OpNode, string Inst> {
@@ -2050,6 +2066,15 @@ def : Pat<(f64 f64imm_vldi:$in),
 defm : PatVrVr<abds, "VABSD">;
 defm : PatVrVrU<abdu, "VABSD">;
 
+// Vector mask set by condition
+def : Pat<(loongarch_vmskltz (v16i8 LSX128:$vj)), (PseudoVMSKLTZ_B LSX128:$vj)>;
+def : Pat<(loongarch_vmskltz (v8i16 LSX128:$vj)), (PseudoVMSKLTZ_H LSX128:$vj)>;
+def : Pat<(loongarch_vmskltz (v4i32 LSX128:$vj)), (PseudoVMSKLTZ_W LSX128:$vj)>;
+def : Pat<(loongarch_vmskltz (v2i64 LSX128:$vj)), (PseudoVMSKLTZ_D LSX128:$vj)>;
+def : Pat<(loongarch_vmskgez (v16i8 LSX128:$vj)), (PseudoVMSKGEZ_B LSX128:$vj)>;
+def : Pat<(loongarch_vmskeqz (v16i8 LSX128:$vj)), (PseudoVMSKEQZ_B LSX128:$vj)>;
+def : Pat<(loongarch_vmsknez (v16i8 LSX128:$vj)), (PseudoVMSKNEZ_B LSX128:$vj)>;
+
 } // Predicates = [HasExtLSX]
 
 /// Intrinsic pattern
diff --git a/llvm/test/CodeGen/LoongArch/lasx/xvmskcond.ll b/llvm/test/CodeGen/LoongArch/lasx/xvmskcond.ll
new file mode 100644
index 0000000000000..93b081e450985
--- /dev/null
+++ b/llvm/test/CodeGen/LoongArch/lasx/xvmskcond.ll
@@ -0,0 +1,200 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc --mtriple=loongarch64 --mattr=+lasx < %s | FileCheck %s
+
+define i32 @xvmsk_eq_allzeros_i8(<32 x i8 > %a) {
+; CHECK-LABEL: xvmsk_eq_allzeros_i8:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    xvmsknz.b $xr0, $xr0
+; CHECK-NEXT:    xvnor.v $xr0, $xr0, $xr0
+; CHECK-NEXT:    xvpickve2gr.wu $a0, $xr0, 0
+; CHECK-NEXT:    xvpickve2gr.wu $a1, $xr0, 4
+; CHECK-NEXT:    bstrins.d $a0, $a1, 31, 16
+; CHECK-NEXT:    ret
+entry:
+  %1 = icmp eq <32 x i8> %a, splat (i8 0)
+  %2 = bitcast <32 x i1> %1 to i32
+  ret i32 %2
+}
+
+define i32 @xvmsk_eq_allones_i8(<32 x i8 > %a) {
+; CHECK-LABEL: xvmsk_eq_allones_i8:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    xvmsknz.b $xr0, $xr0
+; CHECK-NEXT:    xvpickve2gr.wu $a0, $xr0, 0
+; CHECK-NEXT:    xvpickve2gr.wu $a1, $xr0, 4
+; CHECK-NEXT:    bstrins.d $a0, $a1, 31, 16
+; CHECK-NEXT:    ret
+entry:
+  %1 = icmp eq <32 x i8> %a, splat (i8 -1)
+  %2 = bitcast <32 x i1> %1 to i32
+  ret i32 %2
+}
+
+define i32 @xvmsk_sgt_allones_i8(<32 x i8 > %a) {
+; CHECK-LABEL: xvmsk_sgt_allones_i8:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    xvmskgez.b $xr0, $xr0
+; CHECK-NEXT:    xvpickve2gr.wu $a0, $xr0, 0
+; CHECK-NEXT:    xvpickve2gr.wu $a1, $xr0, 4
+; CHECK-NEXT:    bstrins.d $a0, $a1, 31, 16
+; CHECK-NEXT:    ret
+entry:
+  %1 = icmp sgt <32 x i8> %a, splat (i8 -1)
+  %2 = bitcast <32 x i1> %1 to i32
+  ret i32 %2
+}
+
+define i32 @xvmsk_sge_allzeros_i8(<32 x i8 > %a) {
+; CHECK-LABEL: xvmsk_sge_allzeros_i8:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    xvmskgez.b $xr0, $xr0
+; CHECK-NEXT:    xvpickve2gr.wu $a0, $xr0, 0
+; CHECK-NEXT:    xvpickve2gr.wu $a1, $xr0, 4
+; CHECK-NEXT:    bstrins.d $a0, $a1, 31, 16
+; CHECK-NEXT:    ret
+entry:
+  %1 = icmp sge <32 x i8> %a, splat (i8 0)
+  %2 = bitcast <32 x i1> %1 to i32
+  ret i32 %2
+}
+
+define i32 @xvmsk_slt_allzeros_i8(<32 x i8 > %a) {
+; CHECK-LABEL: xvmsk_slt_allzeros_i8:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    xvmskltz.b $xr0, $xr0
+; CHECK-NEXT:    xvpickve2gr.wu $a0, $xr0, 0
+; CHECK-NEXT:    xvpickve2gr.wu $a1, $xr0, 4
+; CHECK-NEXT:    bstrins.d $a0, $a1, 31, 16
+; CHECK-NEXT:    ret
+entry:
+  %1 = icmp slt <32 x i8> %a, splat (i8 0)
+  %2 = bitcast <32 x i1> %1 to i32
+  ret i32 %2
+}
+
+define i16 @xvmsk_slt_allzeros_i16(<16 x i16 > %a) {
+; CHECK-LABEL: xvmsk_slt_allzeros_i16:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    xvmskltz.h $xr0, $xr0
+; CHECK-NEXT:    xvpickve2gr.wu $a0, $xr0, 0
+; CHECK-NEXT:    xvpickve2gr.wu $a1, $xr0, 4
+; CHECK-NEXT:    bstrins.d $a0, $a1, 15, 8
+; CHECK-NEXT:    ret
+entry:
+  %1 = icmp slt <16 x i16> %a, splat (i16 0)
+  %2 = bitcast <16 x i1> %1 to i16
+  ret i16 %2
+}
+
+define i8 @xvmsk_slt_allzeros_i32(<8 x i32 > %a) {
+; CHECK-LABEL: xvmsk_slt_allzeros_i32:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    xvmskltz.w $xr0, $xr0
+; CHECK-NEXT:    xvpickve2gr.wu $a0, $xr0, 0
+; CHECK-NEXT:    xvpickve2gr.wu $a1, $xr0, 4
+; CHECK-NEXT:    bstrins.d $a0, $a1, 7, 4
+; CHECK-NEXT:    ret
+entry:
+  %1 = icmp slt <8 x i32> %a, splat (i32 0)
+  %2 = bitcast <8 x i1> %1 to i8
+  ret i8 %2
+}
+
+define i4 @xvmsk_slt_allzeros_i64(<4 x i64 > %a) {
+; CHECK-LABEL: xvmsk_slt_allzeros_i64:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    xvmskltz.d $xr0, $xr0
+; CHECK-NEXT:    xvpickve2gr.wu $a0, $xr0, 0
+; CHECK-NEXT:    xvpickve2gr.wu $a1, $xr0, 4
+; CHECK-NEXT:    bstrins.d $a0, $a1, 3, 2
+; CHECK-NEXT:    ret
+entry:
+  %1 = icmp slt <4 x i64> %a, splat (i64 0)
+  %2 = bitcast <4 x i1> %1 to i4
+  ret i4 %2
+}
+
+define i32 @xvmsk_sle_allones_i8(<32 x i8 > %a) {
+; CHECK-LABEL: xvmsk_sle_allones_i8:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    xvmskltz.b $xr0, $xr0
+; CHECK-NEXT:    xvpickve2gr.wu $a0, $xr0, 0
+; CHECK-NEXT:    xvpickve2gr.wu $a1, $xr0, 4
+; CHECK-NEXT:    bstrins.d $a0, $a1, 31, 16
+; CHECK-NEXT:    ret
+entry:
+  %1 = icmp sle <32 x i8> %a, splat (i8 -1)
+  %2 = bit...
[truncated]
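
For the 128-bit LSX case, the custom inserter above reads the packed mask with vpickve2gr.hu at index 0, so a v16i8 x < 0 mask is expected to lower to roughly the following; this is a sketch by analogy with the LASX CHECK lines above, not a verbatim excerpt from the truncated LSX test:

# i16 mask of (x < 0) for <16 x i8> %a held in $vr0
vmskltz.b $vr0, $vr0
vpickve2gr.hu $a0, $vr0, 0
ret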

@tangaac (Contributor) commented on Jun 3, 2025

For byte elements we already have >= 0 and != 0; these could be combined to handle > 0 (since x > 0 is equivalent to x >= 0 && x != 0).

This patch seems to have no impact on the llvm-test-suite.

heiher added 2 commits June 3, 2025 11:56
[LoongArch] Lower vector select mask generation to `[X]VMSK{LT,GE,NE}Z` if possible

@heiher (Member, Author) commented on Jun 3, 2025

For byte elements we already have >= 0 and != 0; these could be combined to handle > 0.

We also have V{SEQ,SLE,SLT}; the next step is to support combining them with VMSK{LT,GE,NE}. For example, the expression x > 0 can be transformed as:

not (x <= 0)

vslei.b t, x, 0
vmskgez.b t, t
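
For reference, a minimal IR sketch of the x > 0 pattern such a follow-up would target; this patch itself does not handle it, and the function name is illustrative:

define i16 @vmsk_sgt_allzeros_i8(<16 x i8> %a) {
entry:
  ; x > 0 is not (x <= 0): vslei.b sets a byte of t to all-ones exactly
  ; where x <= 0, and vmskgez.b then collects the bits where t >= 0,
  ; i.e. where x > 0.
  %cmp = icmp sgt <16 x i8> %a, splat (i8 0)
  %mask = bitcast <16 x i1> %cmp to i16
  ret i16 %mask
}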

@SixWeining (Contributor) left a comment:

LGTM. @wangleiat What do you think?

heiher merged commit 2718a47 into llvm:main on Jun 5, 2025. 11 checks passed.
heiher deleted the vmsk-2 branch on June 5, 2025 at 14:17.