@@ -1064,6 +1064,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
   setTargetDAGCombine({ISD::FCOPYSIGN, ISD::MGATHER, ISD::MSCATTER,
                        ISD::VP_GATHER, ISD::VP_SCATTER, ISD::SRA, ISD::SRL,
                        ISD::SHL, ISD::STORE, ISD::SPLAT_VECTOR});
+  if (Subtarget.hasVendorXTHeadMemPair())
+    setTargetDAGCombine({ISD::LOAD, ISD::STORE});
   if (Subtarget.useRVVForFixedLengthVectors())
     setTargetDAGCombine(ISD::BITCAST);

@@ -9653,6 +9655,143 @@ combineBinOp_VLToVWBinOp_VL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   return InputRootReplacement;
 }

+// Helper function for performMemPairCombine.
+// Try to combine the memory loads/stores LSNode1 and LSNode2
+// into a single memory pair operation.
+static SDValue tryMemPairCombine(SelectionDAG &DAG, LSBaseSDNode *LSNode1,
+                                 LSBaseSDNode *LSNode2, SDValue BasePtr,
+                                 uint64_t Imm) {
+  SmallPtrSet<const SDNode *, 32> Visited;
+  SmallVector<const SDNode *, 8> Worklist = {LSNode1, LSNode2};
+
+  if (SDNode::hasPredecessorHelper(LSNode1, Visited, Worklist) ||
+      SDNode::hasPredecessorHelper(LSNode2, Visited, Worklist))
+    return SDValue();
+
+  MachineFunction &MF = DAG.getMachineFunction();
+  const RISCVSubtarget &Subtarget = MF.getSubtarget<RISCVSubtarget>();
+
+  // The new operation has twice the width.
+  MVT XLenVT = Subtarget.getXLenVT();
+  EVT MemVT = LSNode1->getMemoryVT();
+  EVT NewMemVT = (MemVT == MVT::i32) ? MVT::i64 : MVT::i128;
+  MachineMemOperand *MMO = LSNode1->getMemOperand();
+  MachineMemOperand *NewMMO = MF.getMachineMemOperand(
+      MMO, MMO->getPointerInfo(), MemVT == MVT::i32 ? 8 : 16);
+
+  if (LSNode1->getOpcode() == ISD::LOAD) {
+    auto Ext = cast<LoadSDNode>(LSNode1)->getExtensionType();
+    unsigned Opcode;
+    if (MemVT == MVT::i32)
+      Opcode = (Ext == ISD::ZEXTLOAD) ? RISCVISD::TH_LWUD : RISCVISD::TH_LWD;
+    else
+      Opcode = RISCVISD::TH_LDD;
+
+    SDValue Res = DAG.getMemIntrinsicNode(
+        Opcode, SDLoc(LSNode1), DAG.getVTList({XLenVT, XLenVT, MVT::Other}),
+        {LSNode1->getChain(), BasePtr,
+         DAG.getConstant(Imm, SDLoc(LSNode1), XLenVT)},
+        NewMemVT, NewMMO);
+
+    SDValue Node1 =
+        DAG.getMergeValues({Res.getValue(0), Res.getValue(2)}, SDLoc(LSNode1));
+    SDValue Node2 =
+        DAG.getMergeValues({Res.getValue(1), Res.getValue(2)}, SDLoc(LSNode2));
+
+    DAG.ReplaceAllUsesWith(LSNode2, Node2.getNode());
+    return Node1;
+  } else {
+    unsigned Opcode = (MemVT == MVT::i32) ? RISCVISD::TH_SWD : RISCVISD::TH_SDD;
+
+    SDValue Res = DAG.getMemIntrinsicNode(
+        Opcode, SDLoc(LSNode1), DAG.getVTList(MVT::Other),
+        {LSNode1->getChain(), LSNode1->getOperand(1), LSNode2->getOperand(1),
+         BasePtr, DAG.getConstant(Imm, SDLoc(LSNode1), XLenVT)},
+        NewMemVT, NewMMO);
+
+    DAG.ReplaceAllUsesWith(LSNode2, Res.getNode());
+    return Res;
+  }
+}
+
+// Try to combine two adjacent loads/stores to a single pair instruction from
+// the XTHeadMemPair vendor extension.
+static SDValue performMemPairCombine(SDNode *N,
+                                     TargetLowering::DAGCombinerInfo &DCI) {
+  SelectionDAG &DAG = DCI.DAG;
+  MachineFunction &MF = DAG.getMachineFunction();
+  const RISCVSubtarget &Subtarget = MF.getSubtarget<RISCVSubtarget>();
+
+  // Target does not support load/store pair.
+  if (!Subtarget.hasVendorXTHeadMemPair())
+    return SDValue();
+
+  LSBaseSDNode *LSNode1 = cast<LSBaseSDNode>(N);
+  EVT MemVT = LSNode1->getMemoryVT();
+  unsigned OpNum = LSNode1->getOpcode() == ISD::LOAD ? 1 : 2;
+
+  // No volatile, indexed or atomic loads/stores.
+  if (!LSNode1->isSimple() || LSNode1->isIndexed())
+    return SDValue();
+
+  // Function to get a base + constant representation from a memory value.
+  auto ExtractBaseAndOffset = [](SDValue Ptr) -> std::pair<SDValue, uint64_t> {
+    if (Ptr->getOpcode() == ISD::ADD)
+      if (auto *C1 = dyn_cast<ConstantSDNode>(Ptr->getOperand(1)))
+        return {Ptr->getOperand(0), C1->getZExtValue()};
+    return {Ptr, 0};
+  };
+
+  auto [Base1, Offset1] = ExtractBaseAndOffset(LSNode1->getOperand(OpNum));
+
+  SDValue Chain = N->getOperand(0);
+  for (SDNode::use_iterator UI = Chain->use_begin(), UE = Chain->use_end();
+       UI != UE; ++UI) {
+    SDUse &Use = UI.getUse();
+    if (Use.getUser() != N && Use.getResNo() == 0 &&
+        Use.getUser()->getOpcode() == N->getOpcode()) {
+      LSBaseSDNode *LSNode2 = cast<LSBaseSDNode>(Use.getUser());
+
+      // No volatile, indexed or atomic loads/stores.
+      if (!LSNode2->isSimple() || LSNode2->isIndexed())
+        continue;
+
+      // Check if LSNode1 and LSNode2 have the same type and extension.
+      if (LSNode1->getOpcode() == ISD::LOAD)
+        if (cast<LoadSDNode>(LSNode2)->getExtensionType() !=
+            cast<LoadSDNode>(LSNode1)->getExtensionType())
+          continue;
+
+      if (LSNode1->getMemoryVT() != LSNode2->getMemoryVT())
+        continue;
+
+      auto [Base2, Offset2] = ExtractBaseAndOffset(LSNode2->getOperand(OpNum));
+
+      // Check if the base pointer is the same for both instructions.
+      if (Base1 != Base2)
+        continue;
+
+      // Check if the offsets match the XTHeadMemPair encoding constraints.
+      if (MemVT == MVT::i32) {
+        // Check for adjacent i32 values and a 2-bit index.
+        if ((Offset1 + 4 != Offset2) || !isShiftedUInt<2, 3>(Offset1))
+          continue;
+      } else if (MemVT == MVT::i64) {
+        // Check for adjacent i64 values and a 2-bit index.
+        if ((Offset1 + 8 != Offset2) || !isShiftedUInt<2, 4>(Offset1))
+          continue;
+      }
+
+      // Try to combine.
+      if (SDValue Res =
+              tryMemPairCombine(DAG, LSNode1, LSNode2, Base1, Offset1))
+        return Res;
+    }
+  }
+
+  return SDValue();
+}
+
 // Fold
 // (fp_to_int (froundeven X)) -> fcvt X, rne
 // (fp_to_int (ftrunc X)) -> fcvt X, rtz
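Aside (not part of the diff): the isShiftedUInt<2, 3> / isShiftedUInt<2, 4> guards in the combine above mirror the XTHeadMemPair immediate encoding, which only provides a 2-bit, size-scaled offset field. In practice that means a 32-bit pair can only start at byte offsets 0, 8, 16, or 24 (with the second access at +4), and a 64-bit pair at 0, 16, 32, or 48 (second access at +8). Below is a minimal standalone sketch of that predicate; isShiftedUIntSketch and offsetsArePairable are illustrative names introduced here, with isShiftedUIntSketch standing in for llvm::isShiftedUInt from llvm/Support/MathExtras.h.

#include <cstdint>
#include <cstdio>

// Stand-in for llvm::isShiftedUInt<N, S>(X): true if X is an N-bit unsigned
// value shifted left by S bits (i.e. the low S bits are zero).
template <unsigned N, unsigned S> static bool isShiftedUIntSketch(uint64_t X) {
  return (X & ((1ULL << S) - 1)) == 0 && (X >> S) < (1ULL << N);
}

// Mirrors the offset check in performMemPairCombine for two same-sized
// accesses at Offset1 and Offset2 from a common base pointer.
static bool offsetsArePairable(uint64_t Offset1, uint64_t Offset2,
                               bool Is32Bit) {
  if (Is32Bit)
    return Offset1 + 4 == Offset2 && isShiftedUIntSketch<2, 3>(Offset1);
  return Offset1 + 8 == Offset2 && isShiftedUIntSketch<2, 4>(Offset1);
}

int main() {
  // For 32-bit pairs, only first offsets 0, 8, 16 and 24 are combinable.
  for (uint64_t Off = 0; Off <= 32; Off += 4)
    printf("i32 pair at %2llu/%2llu: %s\n", (unsigned long long)Off,
           (unsigned long long)(Off + 4),
           offsetsArePairable(Off, Off + 4, true) ? "combinable" : "skipped");
  return 0;
}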
@@ -10622,7 +10761,15 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     return DAG.getNode(NewOpcode, SDLoc(N), N->getValueType(0), A, B, C, Mask,
                        VL);
   }
+  case ISD::LOAD:
   case ISD::STORE: {
+    if (DCI.isAfterLegalizeDAG())
+      if (SDValue V = performMemPairCombine(N, DCI))
+        return V;
+
+    if (N->getOpcode() != ISD::STORE)
+      break;
+
     auto *Store = cast<StoreSDNode>(N);
     SDValue Val = Store->getValue();
     // Combine store of vmv.x.s/vfmv.f.s to vse with VL of 1.
@@ -13452,6 +13599,11 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(ORC_B)
   NODE_NAME_CASE(ZIP)
   NODE_NAME_CASE(UNZIP)
+  NODE_NAME_CASE(TH_LWD)
+  NODE_NAME_CASE(TH_LWUD)
+  NODE_NAME_CASE(TH_LDD)
+  NODE_NAME_CASE(TH_SWD)
+  NODE_NAME_CASE(TH_SDD)
   NODE_NAME_CASE(VMV_V_X_VL)
   NODE_NAME_CASE(VFMV_V_F_VL)
   NODE_NAME_CASE(VMV_X_S)
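To illustrate the end-to-end effect of the change, here is a small function whose RV64 codegen with the xtheadmempair feature enabled is the kind of pattern this combine targets: two adjacent 64-bit loads from a common base at offsets 0 and 8 satisfy the constraints above and can be selected as a single TH_LDD node. The expected-assembly comment is a sketch based on T-Head's published XTHeadMemPair syntax, not output taken from this commit, and the exact register assignment may differ.

#include <cstdint>

// Two adjacent 64-bit loads from the same base pointer at offsets 0 and 8.
// With something like -march=rv64gc_xtheadmempair (assumed spelling), the DAG
// combine above should pair them; a plausible (hedged) sequence would be:
//   th.ldd a1, a2, (a0), 0, 4
//   add    a0, a1, a2
//   ret
uint64_t sum_pair(const uint64_t *p) { return p[0] + p[1]; }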