@@ -391,8 +391,10 @@ LoongArchTargetLowering::LoongArchTargetLowering(const TargetMachine &TM,
391
391
392
392
// Set DAG combine for 'LSX' feature.
393
393
394
- if (Subtarget.hasExtLSX ())
394
+ if (Subtarget.hasExtLSX ()) {
395
395
setTargetDAGCombine (ISD::INTRINSIC_WO_CHAIN);
396
+ setTargetDAGCombine (ISD::BITCAST);
397
+ }
396
398
397
399
// Compute derived properties from the register classes.
398
400
computeRegisterProperties (Subtarget.getRegisterInfo ());
@@ -4329,6 +4331,85 @@ static SDValue performSRLCombine(SDNode *N, SelectionDAG &DAG,
4329
4331
return SDValue ();
4330
4332
}
4331
4333
4334
+ static SDValue performBITCASTCombine (SDNode *N, SelectionDAG &DAG,
4335
+ TargetLowering::DAGCombinerInfo &DCI,
4336
+ const LoongArchSubtarget &Subtarget) {
4337
+ SDLoc DL (N);
4338
+ EVT VT = N->getValueType (0 );
4339
+ SDValue Src = N->getOperand (0 );
4340
+ EVT SrcVT = Src.getValueType ();
4341
+
4342
+ if (!DCI.isBeforeLegalizeOps ())
4343
+ return SDValue ();
4344
+
4345
+ if (!SrcVT.isSimple () || SrcVT.getScalarType () != MVT::i1)
4346
+ return SDValue ();
4347
+
4348
+ unsigned Opc = ISD::DELETED_NODE;
4349
+ // Combine SETCC and BITCAST into [X]VMSK{LT,GE,NE} when possible
4350
+ if (Src.getOpcode () == ISD::SETCC && Src.hasOneUse ()) {
4351
+ bool UseLASX;
4352
+ EVT CmpVT = Src.getOperand (0 ).getValueType ();
4353
+ EVT EltVT = CmpVT.getVectorElementType ();
4354
+
4355
+ if (Subtarget.hasExtLSX () && CmpVT.getSizeInBits () <= 128 )
4356
+ UseLASX = false ;
4357
+ else if (Subtarget.has32S () && Subtarget.hasExtLASX () &&
4358
+ CmpVT.getSizeInBits () <= 256 )
4359
+ UseLASX = true ;
4360
+ else
4361
+ return SDValue ();
4362
+
4363
+ SDValue SrcN1 = Src.getOperand (1 );
4364
+ switch (cast<CondCodeSDNode>(Src.getOperand (2 ))->get ()) {
4365
+ default :
4366
+ break ;
4367
+ case ISD::SETEQ:
4368
+ // x == 0 => not (vmsknez.b x)
4369
+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4370
+ Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
4371
+ break ;
4372
+ case ISD::SETGT:
4373
+ // x > -1 => vmskgez.b x
4374
+ if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()) && EltVT == MVT::i8 )
4375
+ Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4376
+ break ;
4377
+ case ISD::SETGE:
4378
+ // x >= 0 => vmskgez.b x
4379
+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4380
+ Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
4381
+ break ;
4382
+ case ISD::SETLT:
4383
+ // x < 0 => vmskltz.{b,h,w,d} x
4384
+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) &&
4385
+ (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4386
+ EltVT == MVT::i64 ))
4387
+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4388
+ break ;
4389
+ case ISD::SETLE:
4390
+ // x <= -1 => vmskltz.{b,h,w,d} x
4391
+ if (ISD::isBuildVectorAllOnes (SrcN1.getNode ()) &&
4392
+ (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
4393
+ EltVT == MVT::i64 ))
4394
+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4395
+ break ;
4396
+ case ISD::SETNE:
4397
+ // x != 0 => vmsknez.b x
4398
+ if (ISD::isBuildVectorAllZeros (SrcN1.getNode ()) && EltVT == MVT::i8 )
4399
+ Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
4400
+ break ;
4401
+ }
4402
+ }
4403
+
4404
+ if (Opc == ISD::DELETED_NODE)
4405
+ return SDValue ();
4406
+
4407
+ SDValue V = DAG.getNode (Opc, DL, MVT::i64 , Src.getOperand (0 ));
4408
+ EVT T = EVT::getIntegerVT (*DAG.getContext (), SrcVT.getVectorNumElements ());
4409
+ V = DAG.getZExtOrTrunc (V, DL, T);
4410
+ return DAG.getBitcast (VT, V);
4411
+ }
4412
+
4332
4413
static SDValue performORCombine (SDNode *N, SelectionDAG &DAG,
4333
4414
TargetLowering::DAGCombinerInfo &DCI,
4334
4415
const LoongArchSubtarget &Subtarget) {
@@ -5373,6 +5454,8 @@ SDValue LoongArchTargetLowering::PerformDAGCombine(SDNode *N,
5373
5454
return performSETCCCombine (N, DAG, DCI, Subtarget);
5374
5455
case ISD::SRL:
5375
5456
return performSRLCombine (N, DAG, DCI, Subtarget);
5457
+ case ISD::BITCAST:
5458
+ return performBITCASTCombine (N, DAG, DCI, Subtarget);
5376
5459
case LoongArchISD::BITREV_W:
5377
5460
return performBITREV_WCombine (N, DAG, DCI, Subtarget);
5378
5461
case ISD::INTRINSIC_WO_CHAIN:
@@ -5663,6 +5746,120 @@ static MachineBasicBlock *emitPseudoCTPOP(MachineInstr &MI,
5663
5746
return BB;
5664
5747
}
5665
5748
5749
+ static MachineBasicBlock *
5750
+ emitPseudoVMSKCOND (MachineInstr &MI, MachineBasicBlock *BB,
5751
+ const LoongArchSubtarget &Subtarget) {
5752
+ const TargetInstrInfo *TII = Subtarget.getInstrInfo ();
5753
+ const TargetRegisterClass *RC = &LoongArch::LSX128RegClass;
5754
+ const LoongArchRegisterInfo *TRI = Subtarget.getRegisterInfo ();
5755
+ MachineRegisterInfo &MRI = BB->getParent ()->getRegInfo ();
5756
+ Register Dst = MI.getOperand (0 ).getReg ();
5757
+ Register Src = MI.getOperand (1 ).getReg ();
5758
+ DebugLoc DL = MI.getDebugLoc ();
5759
+ unsigned EleBits = 8 ;
5760
+ unsigned NotOpc = 0 ;
5761
+ unsigned MskOpc;
5762
+
5763
+ switch (MI.getOpcode ()) {
5764
+ default :
5765
+ llvm_unreachable (" Unexpected opcode" );
5766
+ case LoongArch::PseudoVMSKLTZ_B:
5767
+ MskOpc = LoongArch::VMSKLTZ_B;
5768
+ break ;
5769
+ case LoongArch::PseudoVMSKLTZ_H:
5770
+ MskOpc = LoongArch::VMSKLTZ_H;
5771
+ EleBits = 16 ;
5772
+ break ;
5773
+ case LoongArch::PseudoVMSKLTZ_W:
5774
+ MskOpc = LoongArch::VMSKLTZ_W;
5775
+ EleBits = 32 ;
5776
+ break ;
5777
+ case LoongArch::PseudoVMSKLTZ_D:
5778
+ MskOpc = LoongArch::VMSKLTZ_D;
5779
+ EleBits = 64 ;
5780
+ break ;
5781
+ case LoongArch::PseudoVMSKGEZ_B:
5782
+ MskOpc = LoongArch::VMSKGEZ_B;
5783
+ break ;
5784
+ case LoongArch::PseudoVMSKEQZ_B:
5785
+ MskOpc = LoongArch::VMSKNZ_B;
5786
+ NotOpc = LoongArch::VNOR_V;
5787
+ break ;
5788
+ case LoongArch::PseudoVMSKNEZ_B:
5789
+ MskOpc = LoongArch::VMSKNZ_B;
5790
+ break ;
5791
+ case LoongArch::PseudoXVMSKLTZ_B:
5792
+ MskOpc = LoongArch::XVMSKLTZ_B;
5793
+ RC = &LoongArch::LASX256RegClass;
5794
+ break ;
5795
+ case LoongArch::PseudoXVMSKLTZ_H:
5796
+ MskOpc = LoongArch::XVMSKLTZ_H;
5797
+ RC = &LoongArch::LASX256RegClass;
5798
+ EleBits = 16 ;
5799
+ break ;
5800
+ case LoongArch::PseudoXVMSKLTZ_W:
5801
+ MskOpc = LoongArch::XVMSKLTZ_W;
5802
+ RC = &LoongArch::LASX256RegClass;
5803
+ EleBits = 32 ;
5804
+ break ;
5805
+ case LoongArch::PseudoXVMSKLTZ_D:
5806
+ MskOpc = LoongArch::XVMSKLTZ_D;
5807
+ RC = &LoongArch::LASX256RegClass;
5808
+ EleBits = 64 ;
5809
+ break ;
5810
+ case LoongArch::PseudoXVMSKGEZ_B:
5811
+ MskOpc = LoongArch::XVMSKGEZ_B;
5812
+ RC = &LoongArch::LASX256RegClass;
5813
+ break ;
5814
+ case LoongArch::PseudoXVMSKEQZ_B:
5815
+ MskOpc = LoongArch::XVMSKNZ_B;
5816
+ NotOpc = LoongArch::XVNOR_V;
5817
+ RC = &LoongArch::LASX256RegClass;
5818
+ break ;
5819
+ case LoongArch::PseudoXVMSKNEZ_B:
5820
+ MskOpc = LoongArch::XVMSKNZ_B;
5821
+ RC = &LoongArch::LASX256RegClass;
5822
+ break ;
5823
+ }
5824
+
5825
+ Register Msk = MRI.createVirtualRegister (RC);
5826
+ if (NotOpc) {
5827
+ Register Tmp = MRI.createVirtualRegister (RC);
5828
+ BuildMI (*BB, MI, DL, TII->get (MskOpc), Tmp).addReg (Src);
5829
+ BuildMI (*BB, MI, DL, TII->get (NotOpc), Msk)
5830
+ .addReg (Tmp, RegState::Kill)
5831
+ .addReg (Tmp, RegState::Kill);
5832
+ } else {
5833
+ BuildMI (*BB, MI, DL, TII->get (MskOpc), Msk).addReg (Src);
5834
+ }
5835
+
5836
+ if (TRI->getRegSizeInBits (*RC) > 128 ) {
5837
+ Register Lo = MRI.createVirtualRegister (&LoongArch::GPRRegClass);
5838
+ Register Hi = MRI.createVirtualRegister (&LoongArch::GPRRegClass);
5839
+ BuildMI (*BB, MI, DL, TII->get (LoongArch::XVPICKVE2GR_WU), Lo)
5840
+ .addReg (Msk)
5841
+ .addImm (0 );
5842
+ BuildMI (*BB, MI, DL, TII->get (LoongArch::XVPICKVE2GR_WU), Hi)
5843
+ .addReg (Msk, RegState::Kill)
5844
+ .addImm (4 );
5845
+ BuildMI (*BB, MI, DL,
5846
+ TII->get (Subtarget.is64Bit () ? LoongArch::BSTRINS_D
5847
+ : LoongArch::BSTRINS_W),
5848
+ Dst)
5849
+ .addReg (Lo, RegState::Kill)
5850
+ .addReg (Hi, RegState::Kill)
5851
+ .addImm (256 / EleBits - 1 )
5852
+ .addImm (128 / EleBits);
5853
+ } else {
5854
+ BuildMI (*BB, MI, DL, TII->get (LoongArch::VPICKVE2GR_HU), Dst)
5855
+ .addReg (Msk, RegState::Kill)
5856
+ .addImm (0 );
5857
+ }
5858
+
5859
+ MI.eraseFromParent ();
5860
+ return BB;
5861
+ }
5862
+
5666
5863
static bool isSelectPseudo (MachineInstr &MI) {
5667
5864
switch (MI.getOpcode ()) {
5668
5865
default :
@@ -5869,6 +6066,21 @@ MachineBasicBlock *LoongArchTargetLowering::EmitInstrWithCustomInserter(
5869
6066
return emitPseudoXVINSGR2VR (MI, BB, Subtarget);
5870
6067
case LoongArch::PseudoCTPOP:
5871
6068
return emitPseudoCTPOP (MI, BB, Subtarget);
6069
+ case LoongArch::PseudoVMSKLTZ_B:
6070
+ case LoongArch::PseudoVMSKLTZ_H:
6071
+ case LoongArch::PseudoVMSKLTZ_W:
6072
+ case LoongArch::PseudoVMSKLTZ_D:
6073
+ case LoongArch::PseudoVMSKGEZ_B:
6074
+ case LoongArch::PseudoVMSKEQZ_B:
6075
+ case LoongArch::PseudoVMSKNEZ_B:
6076
+ case LoongArch::PseudoXVMSKLTZ_B:
6077
+ case LoongArch::PseudoXVMSKLTZ_H:
6078
+ case LoongArch::PseudoXVMSKLTZ_W:
6079
+ case LoongArch::PseudoXVMSKLTZ_D:
6080
+ case LoongArch::PseudoXVMSKGEZ_B:
6081
+ case LoongArch::PseudoXVMSKEQZ_B:
6082
+ case LoongArch::PseudoXVMSKNEZ_B:
6083
+ return emitPseudoVMSKCOND (MI, BB, Subtarget);
5872
6084
case TargetOpcode::STATEPOINT:
5873
6085
// STATEPOINT is a pseudo instruction which has no implicit defs/uses
5874
6086
// while bl call instruction (where statepoint will be lowered at the
@@ -5990,6 +6202,14 @@ const char *LoongArchTargetLowering::getTargetNodeName(unsigned Opcode) const {
5990
6202
NODE_NAME_CASE (VBSLL)
5991
6203
NODE_NAME_CASE (VBSRL)
5992
6204
NODE_NAME_CASE (VLDREPL)
6205
+ NODE_NAME_CASE (VMSKLTZ)
6206
+ NODE_NAME_CASE (VMSKGEZ)
6207
+ NODE_NAME_CASE (VMSKEQZ)
6208
+ NODE_NAME_CASE (VMSKNEZ)
6209
+ NODE_NAME_CASE (XVMSKLTZ)
6210
+ NODE_NAME_CASE (XVMSKGEZ)
6211
+ NODE_NAME_CASE (XVMSKEQZ)
6212
+ NODE_NAME_CASE (XVMSKNEZ)
5993
6213
}
5994
6214
#undef NODE_NAME_CASE
5995
6215
return nullptr ;
0 commit comments