Skip to content

Commit f3a9dbe

Browse files
authored
[RISCV] Split build_vector into vreg sized pieces when exact VLEN is known (#73606)
If we have a high LMUL build_vector and a known exact VLEN, we can decompose the build_vector into one build_vector per register in the register group. Doing so requires exact knowledge of which elements correspond to each register in the register group, and thus an exact VLEN must be known. Since we no longer have operations which are linear (or worse) in LMUL, this also allows us to lower all build_vectors without resorting to going through the stack.
1 parent b6eb740 commit f3a9dbe

File tree

3 files changed

+270
-421
lines changed

3 files changed

+270
-421
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3105,6 +3105,14 @@ getVSlideup(SelectionDAG &DAG, const RISCVSubtarget &Subtarget, const SDLoc &DL,
31053105
return DAG.getNode(RISCVISD::VSLIDEUP_VL, DL, VT, Ops);
31063106
}
31073107

3108+
static MVT getLMUL1VT(MVT VT) {
3109+
assert(VT.getVectorElementType().getSizeInBits() <= 64 &&
3110+
"Unexpected vector MVT");
3111+
return MVT::getScalableVectorVT(
3112+
VT.getVectorElementType(),
3113+
RISCV::RVVBitsPerBlock / VT.getVectorElementType().getSizeInBits());
3114+
}
3115+
31083116
struct VIDSequence {
31093117
int64_t StepNumerator;
31103118
unsigned StepDenominator;
@@ -3750,6 +3758,37 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
37503758
if (SDValue Res = lowerBuildVectorViaDominantValues(Op, DAG, Subtarget))
37513759
return Res;
37523760

3761+
// If we're compiling for an exact VLEN value, we can split our work per
3762+
// register in the register group.
3763+
const unsigned MinVLen = Subtarget.getRealMinVLen();
3764+
const unsigned MaxVLen = Subtarget.getRealMaxVLen();
3765+
if (MinVLen == MaxVLen && VT.getSizeInBits().getKnownMinValue() > MinVLen) {
3766+
MVT ElemVT = VT.getVectorElementType();
3767+
unsigned ElemsPerVReg = MinVLen / ElemVT.getFixedSizeInBits();
3768+
EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
3769+
MVT OneRegVT = MVT::getVectorVT(ElemVT, ElemsPerVReg);
3770+
MVT M1VT = getContainerForFixedLengthVector(DAG, OneRegVT, Subtarget);
3771+
assert(M1VT == getLMUL1VT(M1VT));
3772+
3773+
// The following semantically builds up a fixed length concat_vector
3774+
// of the component build_vectors. We eagerly lower to scalable and
3775+
// insert_subvector here to avoid DAG combining it back to a large
3776+
// build_vector.
3777+
SmallVector<SDValue> BuildVectorOps(Op->op_begin(), Op->op_end());
3778+
unsigned NumOpElts = M1VT.getVectorMinNumElements();
3779+
SDValue Vec = DAG.getUNDEF(ContainerVT);
3780+
for (unsigned i = 0; i < VT.getVectorNumElements(); i += ElemsPerVReg) {
3781+
auto OneVRegOfOps = ArrayRef(BuildVectorOps).slice(i, ElemsPerVReg);
3782+
SDValue SubBV =
3783+
DAG.getNode(ISD::BUILD_VECTOR, DL, OneRegVT, OneVRegOfOps);
3784+
SubBV = convertToScalableVector(M1VT, SubBV, DAG, Subtarget);
3785+
unsigned InsertIdx = (i / ElemsPerVReg) * NumOpElts;
3786+
Vec = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, ContainerVT, Vec, SubBV,
3787+
DAG.getVectorIdxConstant(InsertIdx, DL));
3788+
}
3789+
return convertFromScalableVector(VT, Vec, DAG, Subtarget);
3790+
}
3791+
37533792
// Cap the cost at a value linear to the number of elements in the vector.
37543793
// The default lowering is to use the stack. The vector store + scalar loads
37553794
// is linear in VL. However, at high lmuls vslide1down and vslidedown end up
@@ -3944,14 +3983,6 @@ static SDValue lowerScalarSplat(SDValue Passthru, SDValue Scalar, SDValue VL,
39443983
return splatSplitI64WithVL(DL, VT, Passthru, Scalar, VL, DAG);
39453984
}
39463985

3947-
static MVT getLMUL1VT(MVT VT) {
3948-
assert(VT.getVectorElementType().getSizeInBits() <= 64 &&
3949-
"Unexpected vector MVT");
3950-
return MVT::getScalableVectorVT(
3951-
VT.getVectorElementType(),
3952-
RISCV::RVVBitsPerBlock / VT.getVectorElementType().getSizeInBits());
3953-
}
3954-
39553986
// This function lowers an insert of a scalar operand Scalar into lane
39563987
// 0 of the vector regardless of the value of VL. The contents of the
39573988
// remaining lanes of the result vector are unspecified. VL is assumed

0 commit comments

Comments
 (0)