From 4237e3f7ed326eff08a82d7b04ec00a5b4336294 Mon Sep 17 00:00:00 2001 From: Craig Topper Date: Tue, 23 Jul 2024 14:57:31 -0700 Subject: [PATCH] [RISCV] Replace VNCLIP RISCVISD opcodes with TRUNCATE_VECTOR_VL_SSAT/USAT opcodes (#100173) Summary: These new opcodes drop the shift amount, rounding mode, and passthru. Making them exactly like TRUNCATE_VECTOR_VL. The shift amount, rounding mode, and passthru are added in isel patterns similar to how we translate TRUNCATE_VECTOR_VL to vnsrl with a shift of 0. This should simplify #99418 a little. Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60251265 --- llvm/lib/Target/RISCV/RISCVISelLowering.cpp | 25 ++-- llvm/lib/Target/RISCV/RISCVISelLowering.h | 10 +- .../Target/RISCV/RISCVInstrInfoVVLPatterns.td | 115 +++++------------- 3 files changed, 44 insertions(+), 106 deletions(-) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 22cdfdcfd80d9d..dd7b0b4ed5ef77 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -2997,13 +2997,9 @@ static SDValue lowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG, CvtEltVT = MVT::getIntegerVT(CvtEltVT.getSizeInBits() / 2); CvtContainerVT = CvtContainerVT.changeVectorElementType(CvtEltVT); // Rounding mode here is arbitrary since we aren't shifting out any bits. - unsigned ClipOpc = IsSigned ? RISCVISD::VNCLIP_VL : RISCVISD::VNCLIPU_VL; - Res = DAG.getNode( - ClipOpc, DL, CvtContainerVT, - {Res, DAG.getConstant(0, DL, CvtContainerVT), - DAG.getUNDEF(CvtContainerVT), Mask, - DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()), - VL}); + unsigned ClipOpc = IsSigned ? RISCVISD::TRUNCATE_VECTOR_VL_SSAT + : RISCVISD::TRUNCATE_VECTOR_VL_USAT; + Res = DAG.getNode(ClipOpc, DL, CvtContainerVT, Res, Mask, VL); } SDValue SplatZero = DAG.getNode( @@ -16643,9 +16639,9 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG, SDValue Val; unsigned ClipOpc; if ((Val = DetectUSatPattern(Src))) - ClipOpc = RISCVISD::VNCLIPU_VL; + ClipOpc = RISCVISD::TRUNCATE_VECTOR_VL_USAT; else if ((Val = DetectSSatPattern(Src))) - ClipOpc = RISCVISD::VNCLIP_VL; + ClipOpc = RISCVISD::TRUNCATE_VECTOR_VL_SSAT; else return SDValue(); @@ -16654,12 +16650,7 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG, do { MVT ValEltVT = MVT::getIntegerVT(ValVT.getScalarSizeInBits() / 2); ValVT = ValVT.changeVectorElementType(ValEltVT); - // Rounding mode here is arbitrary since we aren't shifting out any bits. - Val = DAG.getNode( - ClipOpc, DL, ValVT, - {Val, DAG.getConstant(0, DL, ValVT), DAG.getUNDEF(VT), Mask, - DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()), - VL}); + Val = DAG.getNode(ClipOpc, DL, ValVT, Val, Mask, VL); } while (ValVT != VT); return Val; @@ -20463,6 +20454,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(SPLAT_VECTOR_SPLIT_I64_VL) NODE_NAME_CASE(READ_VLENB) NODE_NAME_CASE(TRUNCATE_VECTOR_VL) + NODE_NAME_CASE(TRUNCATE_VECTOR_VL_SSAT) + NODE_NAME_CASE(TRUNCATE_VECTOR_VL_USAT) NODE_NAME_CASE(VSLIDEUP_VL) NODE_NAME_CASE(VSLIDE1UP_VL) NODE_NAME_CASE(VSLIDEDOWN_VL) @@ -20506,8 +20499,6 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const { NODE_NAME_CASE(UADDSAT_VL) NODE_NAME_CASE(SSUBSAT_VL) NODE_NAME_CASE(USUBSAT_VL) - NODE_NAME_CASE(VNCLIP_VL) - NODE_NAME_CASE(VNCLIPU_VL) NODE_NAME_CASE(FADD_VL) NODE_NAME_CASE(FSUB_VL) NODE_NAME_CASE(FMUL_VL) diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h index 0b0ad9229f0b35..e469a4b1238c74 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -181,6 +181,12 @@ enum NodeType : unsigned { // Truncates a RVV integer vector by one power-of-two. Carries both an extra // mask and VL operand. TRUNCATE_VECTOR_VL, + // Truncates a RVV integer vector by one power-of-two. If the value doesn't + // fit in the destination type, the result is saturated. These correspond to + // vnclip and vnclipu with a shift of 0. Carries both an extra mask and VL + // operand. + TRUNCATE_VECTOR_VL_SSAT, + TRUNCATE_VECTOR_VL_USAT, // Matches the semantics of vslideup/vslidedown. The first operand is the // pass-thru operand, the second is the source vector, the third is the XLenVT // index (either constant or non-constant), the fourth is the mask, the fifth @@ -273,10 +279,6 @@ enum NodeType : unsigned { // Rounding averaging adds of unsigned integers. AVGCEILU_VL, - // Operands are (source, shift, merge, mask, roundmode, vl) - VNCLIPU_VL, - VNCLIP_VL, - MULHS_VL, MULHU_VL, FADD_VL, diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td index cc294bf9254e81..2ed71f6b88974b 100644 --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td @@ -132,9 +132,6 @@ def riscv_uaddsat_vl : SDNode<"RISCVISD::UADDSAT_VL", SDT_RISCVIntBinOp_VL, [S def riscv_ssubsat_vl : SDNode<"RISCVISD::SSUBSAT_VL", SDT_RISCVIntBinOp_VL>; def riscv_usubsat_vl : SDNode<"RISCVISD::USUBSAT_VL", SDT_RISCVIntBinOp_VL>; -def riscv_vnclipu_vl : SDNode<"RISCVISD::VNCLIPU_VL", SDT_RISCVVNBinOp_RM_VL>; -def riscv_vnclip_vl : SDNode<"RISCVISD::VNCLIP_VL", SDT_RISCVVNBinOp_RM_VL>; - def riscv_fadd_vl : SDNode<"RISCVISD::FADD_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>; def riscv_fsub_vl : SDNode<"RISCVISD::FSUB_VL", SDT_RISCVFPBinOp_VL>; def riscv_fmul_vl : SDNode<"RISCVISD::FMUL_VL", SDT_RISCVFPBinOp_VL, [SDNPCommutative]>; @@ -408,12 +405,17 @@ def riscv_ext_vl : PatFrags<(ops node:$A, node:$B, node:$C), [(riscv_sext_vl node:$A, node:$B, node:$C), (riscv_zext_vl node:$A, node:$B, node:$C)]>; +def SDT_RISCVVTRUNCATE_VL : SDTypeProfile<1, 3, [SDTCisVec<0>, + SDTCisSameNumEltsAs<0, 1>, + SDTCisSameNumEltsAs<0, 2>, + SDTCVecEltisVT<2, i1>, + SDTCisVT<3, XLenVT>]>; def riscv_trunc_vector_vl : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL", - SDTypeProfile<1, 3, [SDTCisVec<0>, - SDTCisSameNumEltsAs<0, 1>, - SDTCisSameNumEltsAs<0, 2>, - SDTCVecEltisVT<2, i1>, - SDTCisVT<3, XLenVT>]>>; + SDT_RISCVVTRUNCATE_VL>; +def riscv_trunc_vector_vl_ssat : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL_SSAT", + SDT_RISCVVTRUNCATE_VL>; +def riscv_trunc_vector_vl_usat : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL_USAT", + SDT_RISCVVTRUNCATE_VL>; def SDT_RISCVVWIntBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisInt<0>, SDTCisInt<1>, @@ -650,34 +652,6 @@ class VPatBinaryVL_V; -multiclass VPatBinaryRM_VL_V { - def : Pat<(result_type (vop - (op1_type op1_reg_class:$rs1), - (op2_type op2_reg_class:$rs2), - (result_type result_reg_class:$merge), - (mask_type V0), - (XLenVT timm:$roundmode), - VLOpFrag)), - (!cast(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK") - result_reg_class:$merge, - op1_reg_class:$rs1, - op2_reg_class:$rs2, - (mask_type V0), - (XLenVT timm:$roundmode), - GPR:$vl, sew, TAIL_AGNOSTIC)>; -} - class VPatBinaryVL_V_RM; -multiclass VPatBinaryRM_VL_XI { - def : Pat<(result_type (vop - (vop1_type vop_reg_class:$rs1), - (vop2_type (SplatPatKind (XLenVT xop_kind:$rs2))), - (result_type result_reg_class:$merge), - (mask_type V0), - (XLenVT timm:$roundmode), - VLOpFrag)), - (!cast(instruction_name#_#suffix#_# vlmul.MX#"_MASK") - result_reg_class:$merge, - vop_reg_class:$rs1, - xop_kind:$rs2, - (mask_type V0), - (XLenVT timm:$roundmode), - GPR:$vl, sew, TAIL_AGNOSTIC)>; -} - multiclass VPatBinaryVL_VV_VX vtilist = AllIntegerVectors, bit isSEWAware = 0> { @@ -965,24 +910,6 @@ multiclass VPatBinaryNVL_WV_WX_WI { - foreach VtiToWti = AllWidenableIntVectors in { - defvar vti = VtiToWti.Vti; - defvar wti = VtiToWti.Wti; - defm : VPatBinaryRM_VL_V; - defm : VPatBinaryRM_VL_XI; - defm : VPatBinaryRM_VL_XI(SplatPat#_#uimm5), - uimm5>; - } -} - class VPatBinaryVL_VF; defm : VPatAVGADDVL_VV_VX_RM; // 12.5. Vector Narrowing Fixed-Point Clip Instructions -defm : VPatBinaryRM_NVL_WV_WX_WI; -defm : VPatBinaryRM_NVL_WV_WX_WI; +foreach vtiTowti = AllWidenableIntVectors in { + defvar vti = vtiTowti.Vti; + defvar wti = vtiTowti.Wti; + let Predicates = !listconcat(GetVTypePredicates.Predicates, + GetVTypePredicates.Predicates) in { + // Rounding mode here is arbitrary since we aren't shifting out any bits. + def : Pat<(vti.Vector (riscv_trunc_vector_vl_ssat (wti.Vector wti.RegClass:$rs1), + (vti.Mask V0), + VLOpFrag)), + (!cast("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK") + (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0, + (vti.Mask V0), /*RNU*/0, GPR:$vl, vti.Log2SEW, TA_MA)>; + def : Pat<(vti.Vector (riscv_trunc_vector_vl_usat (wti.Vector wti.RegClass:$rs1), + (vti.Mask V0), + VLOpFrag)), + (!cast("PseudoVNCLIPU_WI_"#vti.LMul.MX#"_MASK") + (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0, + (vti.Mask V0), /*RNU*/0, GPR:$vl, vti.Log2SEW, TA_MA)>; + } +} // 13. Vector Floating-Point Instructions