@@ -9723,6 +9723,21 @@ SDValue RISCVTargetLowering::lowerVPREDUCE(SDValue Op,
9723
9723
Vec, Mask, VL, DL, DAG, Subtarget);
9724
9724
}
9725
9725
9726
+ /// Returns true if \p LHS is known to be equal to \p RHS, taking into account
9727
+ /// if VLEN is exactly known by \p Subtarget and thus vscale when handling
9728
+ /// scalable quantities.
9729
+ static bool isKnownEQ(ElementCount LHS, ElementCount RHS,
9730
+ const RISCVSubtarget &Subtarget) {
9731
+ if (auto VLen = Subtarget.getRealVLen()) {
9732
+ const unsigned Vscale = *VLen / RISCV::RVVBitsPerBlock;
9733
+ if (LHS.isScalable())
9734
+ LHS = ElementCount::getFixed(LHS.getKnownMinValue() * Vscale);
9735
+ if (RHS.isScalable())
9736
+ RHS = ElementCount::getFixed(RHS.getKnownMinValue() * Vscale);
9737
+ }
9738
+ return LHS == RHS;
9739
+ }
9740
+
9726
9741
SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
9727
9742
SelectionDAG &DAG) const {
9728
9743
SDValue Vec = Op.getOperand(0);
@@ -9772,12 +9787,13 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
9772
9787
}
9773
9788
}
9774
9789
9775
- // If the subvector vector is a fixed-length type, we cannot use subregister
9776
- // manipulation to simplify the codegen; we don't know which register of a
9777
- // LMUL group contains the specific subvector as we only know the minimum
9778
- // register size. Therefore we must slide the vector group up the full
9779
- // amount.
9780
- if (SubVecVT.isFixedLengthVector()) {
9790
+ // If the subvector is a fixed-length type and we don't know VLEN
9791
+ // exactly, we cannot use subregister manipulation to simplify the codegen; we
9792
+ // don't know which register of a LMUL group contains the specific subvector
9793
+ // as we only know the minimum register size. Therefore we must slide the
9794
+ // vector group up the full amount.
9795
+ const auto VLen = Subtarget.getRealVLen();
9796
+ if (SubVecVT.isFixedLengthVector() && !VLen) {
9781
9797
if (OrigIdx == 0 && Vec.isUndef() && !VecVT.isFixedLengthVector())
9782
9798
return Op;
9783
9799
MVT ContainerVT = VecVT;
@@ -9825,41 +9841,92 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
9825
9841
return DAG.getBitcast(Op.getValueType(), SubVec);
9826
9842
}
9827
9843
9828
- unsigned SubRegIdx, RemIdx;
9829
- std::tie(SubRegIdx, RemIdx) =
9830
- RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
9831
- VecVT, SubVecVT, OrigIdx, TRI);
9844
+ MVT ContainerVecVT = VecVT;
9845
+ if (VecVT.isFixedLengthVector()) {
9846
+ ContainerVecVT = getContainerForFixedLengthVector(VecVT);
9847
+ Vec = convertToScalableVector(ContainerVecVT, Vec, DAG, Subtarget);
9848
+ }
9832
9849
9833
- RISCVII::VLMUL SubVecLMUL = RISCVTargetLowering::getLMUL(SubVecVT);
9850
+ MVT ContainerSubVecVT = SubVecVT;
9851
+ if (SubVecVT.isFixedLengthVector()) {
9852
+ ContainerSubVecVT = getContainerForFixedLengthVector(SubVecVT);
9853
+ SubVec = convertToScalableVector(ContainerSubVecVT, SubVec, DAG, Subtarget);
9854
+ }
9855
+
9856
+ unsigned SubRegIdx;
9857
+ ElementCount RemIdx;
9858
+ // insert_subvector scales the index by vscale if the subvector is scalable,
9859
+ // and decomposeSubvectorInsertExtractToSubRegs takes this into account. So if
9860
+ // we have a fixed length subvector, we need to adjust the index by 1/vscale.
9861
+ if (SubVecVT.isFixedLengthVector()) {
9862
+ assert(VLen);
9863
+ unsigned Vscale = *VLen / RISCV::RVVBitsPerBlock;
9864
+ auto Decompose =
9865
+ RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
9866
+ ContainerVecVT, ContainerSubVecVT, OrigIdx / Vscale, TRI);
9867
+ SubRegIdx = Decompose.first;
9868
+ RemIdx = ElementCount::getFixed((Decompose.second * Vscale) +
9869
+ (OrigIdx % Vscale));
9870
+ } else {
9871
+ auto Decompose =
9872
+ RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
9873
+ ContainerVecVT, ContainerSubVecVT, OrigIdx, TRI);
9874
+ SubRegIdx = Decompose.first;
9875
+ RemIdx = ElementCount::getScalable(Decompose.second);
9876
+ }
9877
+
9878
+ RISCVII::VLMUL SubVecLMUL = RISCVTargetLowering::getLMUL(ContainerSubVecVT);
9834
9879
bool IsSubVecPartReg = SubVecLMUL == RISCVII::VLMUL::LMUL_F2 ||
9835
9880
SubVecLMUL == RISCVII::VLMUL::LMUL_F4 ||
9836
9881
SubVecLMUL == RISCVII::VLMUL::LMUL_F8;
9882
+ bool AlignedToVecReg = !IsSubVecPartReg;
9883
+ if (SubVecVT.isFixedLengthVector())
9884
+ AlignedToVecReg &= SubVecVT.getSizeInBits() ==
9885
+ ContainerSubVecVT.getSizeInBits().getKnownMinValue() *
9886
+ (*VLen / RISCV::RVVBitsPerBlock);
9837
9887
9838
9888
// 1. If the Idx has been completely eliminated and this subvector's size is
9839
9889
// a vector register or a multiple thereof, or the surrounding elements are
9840
9890
// undef, then this is a subvector insert which naturally aligns to a vector
9841
9891
// register. These can easily be handled using subregister manipulation.
9842
- // 2. If the subvector is smaller than a vector register, then the insertion
9843
- // must preserve the undisturbed elements of the register. We do this by
9844
- // lowering to an EXTRACT_SUBVECTOR grabbing the nearest LMUL=1 vector type
9845
- // (which resolves to a subregister copy), performing a VSLIDEUP to place the
9846
- // subvector within the vector register, and an INSERT_SUBVECTOR of that
9892
+ // 2. If the subvector isn't exactly aligned to a vector register group, then
9893
+ // the insertion must preserve the undisturbed elements of the register. We do
9894
+ // this by lowering to an EXTRACT_SUBVECTOR grabbing the nearest LMUL=1 vector
9895
+ // type (which resolves to a subregister copy), performing a VSLIDEUP to place
9896
+ // the subvector within the vector register, and an INSERT_SUBVECTOR of that
9847
9897
// LMUL=1 type back into the larger vector (resolving to another subregister
9848
9898
// operation). See below for how our VSLIDEUP works. We go via a LMUL=1 type
9849
9899
// to avoid allocating a large register group to hold our subvector.
9850
- if (RemIdx == 0 && (!IsSubVecPartReg || Vec.isUndef()))
9900
+ if (RemIdx.isZero() && (AlignedToVecReg || Vec.isUndef())) {
9901
+ if (SubVecVT.isFixedLengthVector()) {
9902
+ // We may get NoSubRegister if inserting at index 0 and the subvec
9903
+ // container is the same as the vector, e.g. vec=v4i32,subvec=v4i32,idx=0
9904
+ if (SubRegIdx == RISCV::NoSubRegister) {
9905
+ assert(OrigIdx == 0);
9906
+ return Op;
9907
+ }
9908
+
9909
+ SDValue Insert =
9910
+ DAG.getTargetInsertSubreg(SubRegIdx, DL, ContainerVecVT, Vec, SubVec);
9911
+ if (VecVT.isFixedLengthVector())
9912
+ Insert = convertFromScalableVector(VecVT, Insert, DAG, Subtarget);
9913
+ return Insert;
9914
+ }
9851
9915
return Op;
9916
+ }
9852
9917
9853
9918
// VSLIDEUP works by leaving elements 0<=i<OFFSET undisturbed, elements
9854
9919
// OFFSET<=i<VL set to the "subvector" and vl<=i<VLMAX set to the tail policy
9855
9920
// (in our case undisturbed). This means we can set up a subvector insertion
9856
9921
// where OFFSET is the insertion offset, and the VL is the OFFSET plus the
9857
9922
// size of the subvector.
9858
- MVT InterSubVT = VecVT ;
9923
+ MVT InterSubVT = ContainerVecVT ;
9859
9924
SDValue AlignedExtract = Vec;
9860
- unsigned AlignedIdx = OrigIdx - RemIdx;
9861
- if (VecVT.bitsGT(getLMUL1VT(VecVT))) {
9862
- InterSubVT = getLMUL1VT(VecVT);
9925
+ unsigned AlignedIdx = OrigIdx - RemIdx.getKnownMinValue();
9926
+ if (SubVecVT.isFixedLengthVector())
9927
+ AlignedIdx /= *VLen / RISCV::RVVBitsPerBlock;
9928
+ if (ContainerVecVT.bitsGT(getLMUL1VT(ContainerVecVT))) {
9929
+ InterSubVT = getLMUL1VT(ContainerVecVT);
9863
9930
// Extract a subvector equal to the nearest full vector register type. This
9864
9931
// should resolve to a EXTRACT_SUBREG instruction.
9865
9932
AlignedExtract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InterSubVT, Vec,
@@ -9870,25 +9937,23 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
9870
9937
DAG.getUNDEF(InterSubVT), SubVec,
9871
9938
DAG.getVectorIdxConstant(0, DL));
9872
9939
9873
- auto [Mask, VL] = getDefaultScalableVLOps (VecVT, DL, DAG, Subtarget);
9940
+ auto [Mask, VL] = getDefaultVLOps (VecVT, ContainerVecVT , DL, DAG, Subtarget);
9874
9941
9875
- ElementCount EndIndex =
9876
- ElementCount::getScalable(RemIdx) + SubVecVT.getVectorElementCount();
9877
- VL = computeVLMax(SubVecVT, DL, DAG);
9942
+ ElementCount EndIndex = RemIdx + SubVecVT.getVectorElementCount();
9943
+ VL = DAG.getElementCount(DL, XLenVT, SubVecVT.getVectorElementCount());
9878
9944
9879
9945
// Use tail agnostic policy if we're inserting over InterSubVT's tail.
9880
9946
unsigned Policy = RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED;
9881
- if (EndIndex == InterSubVT.getVectorElementCount())
9947
+ if (isKnownEQ( EndIndex, InterSubVT.getVectorElementCount(), Subtarget ))
9882
9948
Policy = RISCVII::TAIL_AGNOSTIC;
9883
9949
9884
9950
// If we're inserting into the lowest elements, use a tail undisturbed
9885
9951
// vmv.v.v.
9886
- if (RemIdx == 0 ) {
9952
+ if (RemIdx.isZero() ) {
9887
9953
SubVec = DAG.getNode(RISCVISD::VMV_V_V_VL, DL, InterSubVT, AlignedExtract,
9888
9954
SubVec, VL);
9889
9955
} else {
9890
- SDValue SlideupAmt =
9891
- DAG.getVScale(DL, XLenVT, APInt(XLenVT.getSizeInBits(), RemIdx));
9956
+ SDValue SlideupAmt = DAG.getElementCount(DL, XLenVT, RemIdx);
9892
9957
9893
9958
// Construct the vector length corresponding to RemIdx + length(SubVecVT).
9894
9959
VL = DAG.getNode(ISD::ADD, DL, XLenVT, SlideupAmt, VL);
@@ -9899,10 +9964,13 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
9899
9964
9900
9965
// If required, insert this subvector back into the correct vector register.
9901
9966
// This should resolve to an INSERT_SUBREG instruction.
9902
- if (VecVT .bitsGT(InterSubVT))
9903
- SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VecVT , Vec, SubVec,
9967
+ if (ContainerVecVT .bitsGT(InterSubVT))
9968
+ SubVec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVecVT , Vec, SubVec,
9904
9969
DAG.getVectorIdxConstant(AlignedIdx, DL));
9905
9970
9971
+ if (VecVT.isFixedLengthVector())
9972
+ SubVec = convertFromScalableVector(VecVT, SubVec, DAG, Subtarget);
9973
+
9906
9974
// We might have bitcast from a mask type: cast back to the original type if
9907
9975
// required.
9908
9976
return DAG.getBitcast(Op.getSimpleValueType(), SubVec);
0 commit comments