Skip to content

Commit

Permalink
[RISCV] Remove RISCVISD::SPLAT_VECTOR_I64 in favor of RISCVISD::VMV_V…
Browse files Browse the repository at this point in the history
…_X_VL.

SPLAT_VECTOR_I64 has the same semantics as RISCVISD::VMV_V_X_VL, it
just assumed VLMax instead of carrying a VL operand.

Include order of RISCVInstrInfoVSDPatterns.td and RISCVInstrInfoVVLPatterns.td
has been swapped to avoid moving riscv_vmv_v_x_vl into
RISCVInstrInfoVSDPatterns.td and to allow moving other "_vl" SDNodes back to
RISCVInstrInfoVVLPatterns.td

Reviewed By: frasercrmck

Differential Revision: https://reviews.llvm.org/D118841
  • Loading branch information
topperc committed Feb 3, 2022
1 parent bad0301 commit 2349fb0
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 45 deletions.
13 changes: 5 additions & 8 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1885,7 +1885,6 @@ bool RISCVDAGToDAGISel::selectVLOp(SDValue N, SDValue &VL) {

bool RISCVDAGToDAGISel::selectVSplat(SDValue N, SDValue &SplatVal) {
if (N.getOpcode() != ISD::SPLAT_VECTOR &&
N.getOpcode() != RISCVISD::SPLAT_VECTOR_I64 &&
N.getOpcode() != RISCVISD::VMV_V_X_VL)
return false;
SplatVal = N.getOperand(0);
Expand All @@ -1899,18 +1898,17 @@ static bool selectVSplatSimmHelper(SDValue N, SDValue &SplatVal,
const RISCVSubtarget &Subtarget,
ValidateFn ValidateImm) {
if ((N.getOpcode() != ISD::SPLAT_VECTOR &&
N.getOpcode() != RISCVISD::SPLAT_VECTOR_I64 &&
N.getOpcode() != RISCVISD::VMV_V_X_VL) ||
!isa<ConstantSDNode>(N.getOperand(0)))
return false;

int64_t SplatImm = cast<ConstantSDNode>(N.getOperand(0))->getSExtValue();

// ISD::SPLAT_VECTOR, RISCVISD::SPLAT_VECTOR_I64 and RISCVISD::VMV_V_X_VL
// share semantics when the operand type is wider than the resulting vector
// element type: an implicit truncation first takes place. Therefore, perform
// a manual truncation/sign-extension in order to ignore any truncated bits
// and catch any zero-extended immediate.
// ISD::SPLAT_VECTOR, RISCVISD::VMV_V_X_VL share semantics when the operand
// type is wider than the resulting vector element type: an implicit
// truncation first takes place. Therefore, perform a manual
// truncation/sign-extension in order to ignore any truncated bits and catch
// any zero-extended immediate.
// For example, we wish to match (i8 -1) -> (XLenVT 255) as a simm5 by first
// sign-extending to (XLenVT -1).
MVT XLenVT = Subtarget.getXLenVT();
Expand Down Expand Up @@ -1948,7 +1946,6 @@ bool RISCVDAGToDAGISel::selectVSplatSimm5Plus1NonZero(SDValue N,

bool RISCVDAGToDAGISel::selectVSplatUimm5(SDValue N, SDValue &SplatVal) {
if ((N.getOpcode() != ISD::SPLAT_VECTOR &&
N.getOpcode() != RISCVISD::SPLAT_VECTOR_I64 &&
N.getOpcode() != RISCVISD::VMV_V_X_VL) ||
!isa<ConstantSDNode>(N.getOperand(0)))
return false;
Expand Down
29 changes: 19 additions & 10 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4025,7 +4025,7 @@ SDValue RISCVTargetLowering::lowerVectorMaskSplat(SDValue Op,
// Custom-lower a SPLAT_VECTOR_PARTS where XLEN<SEW, as the SEW element type is
// illegal (currently only vXi64 RV32).
// FIXME: We could also catch non-constant sign-extended i32 values and lower
// them to SPLAT_VECTOR_I64
// them to VMV_V_X_VL.
SDValue RISCVTargetLowering::lowerSPLAT_VECTOR_PARTS(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
Expand Down Expand Up @@ -4054,14 +4054,18 @@ SDValue RISCVTargetLowering::lowerSPLAT_VECTOR_PARTS(SDValue Op,
// If Hi constant is all the same sign bit as Lo, lower this as a custom
// node in order to try and match RVV vector/scalar instructions.
if ((LoC >> 31) == HiC)
return DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, Lo);
return DAG.getNode(
RISCVISD::VMV_V_X_VL, DL, VecVT, Lo,
DAG.getTargetConstant(RISCV::VLMaxSentinel, DL, MVT::i32));
}

// Detect cases where Hi is (SRA Lo, 31) which means Hi is Lo sign extended.
if (Hi.getOpcode() == ISD::SRA && Hi.getOperand(0) == Lo &&
isa<ConstantSDNode>(Hi.getOperand(1)) &&
Hi.getConstantOperandVal(1) == 31)
return DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, Lo);
return DAG.getNode(
RISCVISD::VMV_V_X_VL, DL, VecVT, Lo,
DAG.getTargetConstant(RISCV::VLMaxSentinel, DL, MVT::i32));

// Fall back to use a stack store and stride x0 vector load. Use X0 as VL.
return DAG.getNode(RISCVISD::SPLAT_VECTOR_SPLIT_I64_VL, DL, VecVT, Lo, Hi,
Expand Down Expand Up @@ -4089,17 +4093,20 @@ SDValue RISCVTargetLowering::lowerVectorMaskExt(SDValue Op, SelectionDAG &DAG,
// Be careful not to introduce illegal scalar types at this stage, and be
// careful also about splatting constants as on RV32, vXi64 SPLAT_VECTOR is
// illegal and must be expanded. Since we know that the constants are
// sign-extended 32-bit values, we use SPLAT_VECTOR_I64 directly.
// sign-extended 32-bit values, we use VMV_V_X_VL directly.
bool IsRV32E64 =
!Subtarget.is64Bit() && VecVT.getVectorElementType() == MVT::i64;

if (!IsRV32E64) {
SplatZero = DAG.getSplatVector(VecVT, DL, SplatZero);
SplatTrueVal = DAG.getSplatVector(VecVT, DL, SplatTrueVal);
} else {
SplatZero = DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, SplatZero);
SplatZero =
DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VecVT, SplatZero,
DAG.getTargetConstant(RISCV::VLMaxSentinel, DL, XLenVT));
SplatTrueVal =
DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, SplatTrueVal);
DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VecVT, SplatTrueVal,
DAG.getTargetConstant(RISCV::VLMaxSentinel, DL, XLenVT));
}

return DAG.getNode(ISD::VSELECT, DL, VecVT, Src, SplatTrueVal, SplatZero);
Expand Down Expand Up @@ -5416,7 +5423,9 @@ SDValue RISCVTargetLowering::lowerVECTOR_REVERSE(SDValue Op,
if (!IsRV32E64)
SplatVL = DAG.getSplatVector(IntVT, DL, VLMinus1);
else
SplatVL = DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, IntVT, VLMinus1);
SplatVL =
DAG.getNode(RISCVISD::VMV_V_X_VL, DL, IntVT, VLMinus1,
DAG.getTargetConstant(RISCV::VLMaxSentinel, DL, XLenVT));

SDValue VID = DAG.getNode(RISCVISD::VID_VL, DL, IntVT, Mask, VL);
SDValue Indices =
Expand Down Expand Up @@ -8134,8 +8143,9 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
// We don't need the upper 32 bits of a 64-bit element for a shift amount.
SDLoc DL(N);
EVT VT = N->getValueType(0);
ShAmt =
DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VT, ShAmt.getOperand(0));
ShAmt = DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, ShAmt.getOperand(0),
DAG.getTargetConstant(RISCV::VLMaxSentinel, DL,
Subtarget.getXLenVT()));
return DAG.getNode(N->getOpcode(), DL, VT, N->getOperand(0), ShAmt);
}
break;
Expand Down Expand Up @@ -10290,7 +10300,6 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(VMV_X_S)
NODE_NAME_CASE(VMV_S_X_VL)
NODE_NAME_CASE(VFMV_S_F_VL)
NODE_NAME_CASE(SPLAT_VECTOR_I64)
NODE_NAME_CASE(SPLAT_VECTOR_SPLIT_I64_VL)
NODE_NAME_CASE(READ_VLENB)
NODE_NAME_CASE(TRUNCATE_VECTOR_VL)
Expand Down
3 changes: 0 additions & 3 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,6 @@ enum NodeType : unsigned {
VMV_S_X_VL,
// VFMV_S_F_VL matches the semantics of vfmv.s.f. It carries a VL operand.
VFMV_S_F_VL,
// Splats an i64 scalar to a vector type (with element type i64) where the
// scalar is a sign-extended i32.
SPLAT_VECTOR_I64,
// Splats an 64-bit value that has been split into two i32 parts. This is
// expanded late to two scalar stores and a stride 0 vector load.
SPLAT_VECTOR_SPLIT_I64_VL,
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
Original file line number Diff line number Diff line change
Expand Up @@ -5097,5 +5097,5 @@ let Predicates = [HasVInstructionsAnyF] in {
} // Predicates = [HasVInstructionsAnyF]

// Include the non-intrinsic ISel patterns
include "RISCVInstrInfoVSDPatterns.td"
include "RISCVInstrInfoVVLPatterns.td"
include "RISCVInstrInfoVSDPatterns.td"
24 changes: 1 addition & 23 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,9 @@
// Helpers to define the SDNode patterns.
//===----------------------------------------------------------------------===//

def SDTSplatI64 : SDTypeProfile<1, 1, [
SDTCVecEltisVT<0, i64>, SDTCisVT<1, i32>
]>;

def rv32_splat_i64 : SDNode<"RISCVISD::SPLAT_VECTOR_I64", SDTSplatI64>;

def SDT_RISCVVMSETCLR_VL : SDTypeProfile<1, 1, [SDTCVecEltisVT<0, i1>,
SDTCisVT<1, XLenVT>]>;
def riscv_vmclr_vl : SDNode<"RISCVISD::VMCLR_VL", SDT_RISCVVMSETCLR_VL>;
def riscv_vmset_vl : SDNode<"RISCVISD::VMSET_VL", SDT_RISCVVMSETCLR_VL>;

def rvv_vnot : PatFrag<(ops node:$in),
(xor node:$in, (riscv_vmset_vl (XLenVT srcvalue)))>;

// Give explicit Complexity to prefer simm5/uimm5.
def SplatPat : ComplexPattern<vAny, 1, "selectVSplat", [splat_vector, rv32_splat_i64], [], 1>;
def SplatPat_simm5 : ComplexPattern<vAny, 1, "selectVSplatSimm5", [splat_vector, rv32_splat_i64], [], 2>;
def SplatPat_uimm5 : ComplexPattern<vAny, 1, "selectVSplatUimm5", [splat_vector, rv32_splat_i64], [], 2>;
def SplatPat_simm5_plus1
: ComplexPattern<vAny, 1, "selectVSplatSimm5Plus1",
[splat_vector, rv32_splat_i64], [], 2>;
def SplatPat_simm5_plus1_nonzero
: ComplexPattern<vAny, 1, "selectVSplatSimm5Plus1NonZero",
[splat_vector, rv32_splat_i64], [], 2>;

class SwapHelper<dag Prefix, dag A, dag B, dag Suffix, bit swap> {
dag Value = !con(Prefix, !if(swap, B, A), !if(swap, A, B), Suffix);
}
Expand Down Expand Up @@ -526,7 +504,7 @@ foreach vti = AllIntegerVectors in {
}
foreach vti = [VI64M1, VI64M2, VI64M4, VI64M8] in {
def : Pat<(shl (vti.Vector vti.RegClass:$rs1),
(vti.Vector (rv32_splat_i64 (XLenVT 1)))),
(vti.Vector (riscv_vmv_v_x_vl 1, (XLenVT srcvalue)))),
(!cast<Instruction>("PseudoVADD_VV_"# vti.LMul.MX)
vti.RegClass:$rs1, vti.RegClass:$rs1, vti.AVL, vti.Log2SEW)>;

Expand Down
16 changes: 16 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,11 @@ def SDT_RISCVSelect_VL : SDTypeProfile<1, 4, [
def riscv_vselect_vl : SDNode<"RISCVISD::VSELECT_VL", SDT_RISCVSelect_VL>;
def riscv_vp_merge_vl : SDNode<"RISCVISD::VP_MERGE_VL", SDT_RISCVSelect_VL>;

def SDT_RISCVVMSETCLR_VL : SDTypeProfile<1, 1, [SDTCVecEltisVT<0, i1>,
SDTCisVT<1, XLenVT>]>;
def riscv_vmclr_vl : SDNode<"RISCVISD::VMCLR_VL", SDT_RISCVVMSETCLR_VL>;
def riscv_vmset_vl : SDNode<"RISCVISD::VMSET_VL", SDT_RISCVVMSETCLR_VL>;

def SDT_RISCVMaskBinOp_VL : SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>,
SDTCisSameAs<0, 2>,
SDTCVecEltisVT<0, i1>,
Expand Down Expand Up @@ -273,6 +278,17 @@ foreach kind = ["ADD", "UMAX", "SMAX", "UMIN", "SMIN", "AND", "OR", "XOR",
"FADD", "SEQ_FADD", "FMIN", "FMAX"] in
def rvv_vecreduce_#kind#_vl : SDNode<"RISCVISD::VECREDUCE_"#kind#"_VL", SDTRVVVecReduce>;

// Give explicit Complexity to prefer simm5/uimm5.
def SplatPat : ComplexPattern<vAny, 1, "selectVSplat", [splat_vector], [], 1>;
def SplatPat_simm5 : ComplexPattern<vAny, 1, "selectVSplatSimm5", [splat_vector], [], 2>;
def SplatPat_uimm5 : ComplexPattern<vAny, 1, "selectVSplatUimm5", [splat_vector], [], 2>;
def SplatPat_simm5_plus1
: ComplexPattern<vAny, 1, "selectVSplatSimm5Plus1",
[splat_vector], [], 2>;
def SplatPat_simm5_plus1_nonzero
: ComplexPattern<vAny, 1, "selectVSplatSimm5Plus1NonZero",
[splat_vector], [], 2>;

// Ignore the vl operand.
def SplatFPOp : PatFrag<(ops node:$op),
(riscv_vfmv_v_f_vl node:$op, srcvalue)>;
Expand Down

0 comments on commit 2349fb0

Please sign in to comment.