[RISCV] Replace VNCLIP RISCVISD opcodes with TRUNCATE_VECTOR_VL_SSAT/USAT opcodes #100173

Merged
merged 1 commit into from
Jul 23, 2024

Conversation

topperc
Collaborator

@topperc topperc commented Jul 23, 2024

These new opcodes drop the shift amount, rounding mode, and passthru operands, making them exactly like TRUNCATE_VECTOR_VL. The shift amount, rounding mode, and passthru are instead added in the isel patterns, similar to how we translate TRUNCATE_VECTOR_VL to vnsrl with a shift of 0.

This should simplify #99418 a little.
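
For readers less familiar with the clip instructions, the per-element behaviour these opcodes are meant to capture (a narrowing truncate that saturates, i.e. vnclip/vnclipu with a shift of 0) can be sketched in scalar C++. This is only an illustrative model for a single i16 to i8 step: the function names `truncSSat`, `truncUSat`, and `selfCheck` are made up for this sketch and are not part of LLVM, and the mask and VL operands are not modelled.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical scalar model of TRUNCATE_VECTOR_VL_SSAT for one element:
// the source is treated as signed and clamped to the signed range of the
// narrower type before the upper bits are dropped.
constexpr int8_t truncSSat(int16_t x) {
  return static_cast<int8_t>(x > INT8_MAX ? INT8_MAX
                             : x < INT8_MIN ? INT8_MIN
                                            : x);
}

// Hypothetical scalar model of TRUNCATE_VECTOR_VL_USAT for one element:
// the source is treated as unsigned and clamped to the unsigned maximum
// of the narrower type.
constexpr uint8_t truncUSat(uint16_t x) {
  return static_cast<uint8_t>(x > UINT8_MAX ? UINT8_MAX : x);
}

// Small runtime check of the model; with a shift of 0 no bits are shifted
// out, so no rounding mode appears anywhere in the computation.
inline void selfCheck() {
  assert(truncSSat(300) == 127);   // too large: saturates to INT8_MAX
  assert(truncSSat(-300) == -128); // too small: saturates to INT8_MIN
  assert(truncSSat(-5) == -5);     // in range: plain truncation
  assert(truncUSat(300) == 255);   // too large: saturates to UINT8_MAX
  assert(truncUSat(200) == 200);   // in range: plain truncation
}
```

Because nothing is shifted out, the rounding mode has no effect on the result, which is consistent with the "Rounding mode here is arbitrary" comment carried in the patch.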

@llvmbot
Collaborator

llvmbot commented Jul 23, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/100173.diff

3 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+8-17)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+6-4)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td (+30-85)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 22cdfdcfd80d9..dd7b0b4ed5ef7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -2997,13 +2997,9 @@ static SDValue lowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG,
     CvtEltVT = MVT::getIntegerVT(CvtEltVT.getSizeInBits() / 2);
     CvtContainerVT = CvtContainerVT.changeVectorElementType(CvtEltVT);
     // Rounding mode here is arbitrary since we aren't shifting out any bits.
-    unsigned ClipOpc = IsSigned ? RISCVISD::VNCLIP_VL : RISCVISD::VNCLIPU_VL;
-    Res = DAG.getNode(
-        ClipOpc, DL, CvtContainerVT,
-        {Res, DAG.getConstant(0, DL, CvtContainerVT),
-         DAG.getUNDEF(CvtContainerVT), Mask,
-         DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()),
-         VL});
+    unsigned ClipOpc = IsSigned ? RISCVISD::TRUNCATE_VECTOR_VL_SSAT
+                                : RISCVISD::TRUNCATE_VECTOR_VL_USAT;
+    Res = DAG.getNode(ClipOpc, DL, CvtContainerVT, Res, Mask, VL);
   }
 
   SDValue SplatZero = DAG.getNode(
@@ -16643,9 +16639,9 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
   SDValue Val;
   unsigned ClipOpc;
   if ((Val = DetectUSatPattern(Src)))
-    ClipOpc = RISCVISD::VNCLIPU_VL;
+    ClipOpc = RISCVISD::TRUNCATE_VECTOR_VL_USAT;
   else if ((Val = DetectSSatPattern(Src)))
-    ClipOpc = RISCVISD::VNCLIP_VL;
+    ClipOpc = RISCVISD::TRUNCATE_VECTOR_VL_SSAT;
   else
     return SDValue();
 
@@ -16654,12 +16650,7 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
   do {
     MVT ValEltVT = MVT::getIntegerVT(ValVT.getScalarSizeInBits() / 2);
     ValVT = ValVT.changeVectorElementType(ValEltVT);
-    // Rounding mode here is arbitrary since we aren't shifting out any bits.
-    Val = DAG.getNode(
-        ClipOpc, DL, ValVT,
-        {Val, DAG.getConstant(0, DL, ValVT), DAG.getUNDEF(VT), Mask,
-         DAG.getTargetConstant(RISCVVXRndMode::RNU, DL, Subtarget.getXLenVT()),
-         VL});
+    Val = DAG.getNode(ClipOpc, DL, ValVT, Val, Mask, VL);
   } while (ValVT != VT);
 
   return Val;
@@ -20463,6 +20454,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(SPLAT_VECTOR_SPLIT_I64_VL)
   NODE_NAME_CASE(READ_VLENB)
   NODE_NAME_CASE(TRUNCATE_VECTOR_VL)
+  NODE_NAME_CASE(TRUNCATE_VECTOR_VL_SSAT)
+  NODE_NAME_CASE(TRUNCATE_VECTOR_VL_USAT)
   NODE_NAME_CASE(VSLIDEUP_VL)
   NODE_NAME_CASE(VSLIDE1UP_VL)
   NODE_NAME_CASE(VSLIDEDOWN_VL)
@@ -20506,8 +20499,6 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(UADDSAT_VL)
   NODE_NAME_CASE(SSUBSAT_VL)
   NODE_NAME_CASE(USUBSAT_VL)
-  NODE_NAME_CASE(VNCLIP_VL)
-  NODE_NAME_CASE(VNCLIPU_VL)
   NODE_NAME_CASE(FADD_VL)
   NODE_NAME_CASE(FSUB_VL)
   NODE_NAME_CASE(FMUL_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 0b0ad9229f0b3..e469a4b1238c7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -181,6 +181,12 @@ enum NodeType : unsigned {
   // Truncates a RVV integer vector by one power-of-two. Carries both an extra
   // mask and VL operand.
   TRUNCATE_VECTOR_VL,
+  // Truncates a RVV integer vector by one power-of-two. If the value doesn't
+  // fit in the destination type, the result is saturated. These correspond to
+  // vnclip and vnclipu with a shift of 0. Carries both an extra mask and VL
+  // operand.
+  TRUNCATE_VECTOR_VL_SSAT,
+  TRUNCATE_VECTOR_VL_USAT,
   // Matches the semantics of vslideup/vslidedown. The first operand is the
   // pass-thru operand, the second is the source vector, the third is the XLenVT
   // index (either constant or non-constant), the fourth is the mask, the fifth
@@ -273,10 +279,6 @@ enum NodeType : unsigned {
   // Rounding averaging adds of unsigned integers.
   AVGCEILU_VL,
 
-  // Operands are (source, shift, merge, mask, roundmode, vl)
-  VNCLIPU_VL,
-  VNCLIP_VL,
-
   MULHS_VL,
   MULHU_VL,
   FADD_VL,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index cc294bf9254e8..2ed71f6b88974 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -132,9 +132,6 @@ def riscv_uaddsat_vl   : SDNode<"RISCVISD::UADDSAT_VL", SDT_RISCVIntBinOp_VL, [S
 def riscv_ssubsat_vl   : SDNode<"RISCVISD::SSUBSAT_VL", SDT_RISCVIntBinOp_VL>;
 def riscv_usubsat_vl   : SDNode<"RISCVISD::USUBSAT_VL", SDT_RISCVIntBinOp_VL>;
 
-def riscv_vnclipu_vl : SDNode<"RISCVISD::VNCLIPU_VL", SDT_RISCVVNBinOp_RM_VL>;
-def riscv_vnclip_vl : SDNode<"RISCVISD::VNCLIP_VL", SDT_RISCVVNBinOp_RM_VL>;
-
 def riscv_fadd_vl  : SDNode<"RISCVISD::FADD_VL",  SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
 def riscv_fsub_vl  : SDNode<"RISCVISD::FSUB_VL",  SDT_RISCVFPBinOp_VL>;
 def riscv_fmul_vl  : SDNode<"RISCVISD::FMUL_VL",  SDT_RISCVFPBinOp_VL, [SDNPCommutative]>;
@@ -408,12 +405,17 @@ def riscv_ext_vl : PatFrags<(ops node:$A, node:$B, node:$C),
                             [(riscv_sext_vl node:$A, node:$B, node:$C),
                              (riscv_zext_vl node:$A, node:$B, node:$C)]>;
 
+def SDT_RISCVVTRUNCATE_VL : SDTypeProfile<1, 3, [SDTCisVec<0>,
+                                                 SDTCisSameNumEltsAs<0, 1>,
+                                                 SDTCisSameNumEltsAs<0, 2>,
+                                                 SDTCVecEltisVT<2, i1>,
+                                                 SDTCisVT<3, XLenVT>]>;
 def riscv_trunc_vector_vl : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL",
-                                   SDTypeProfile<1, 3, [SDTCisVec<0>,
-                                                        SDTCisSameNumEltsAs<0, 1>,
-                                                        SDTCisSameNumEltsAs<0, 2>,
-                                                        SDTCVecEltisVT<2, i1>,
-                                                        SDTCisVT<3, XLenVT>]>>;
+                                   SDT_RISCVVTRUNCATE_VL>;
+def riscv_trunc_vector_vl_ssat : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL_SSAT",
+                                        SDT_RISCVVTRUNCATE_VL>;
+def riscv_trunc_vector_vl_usat : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL_USAT",
+                                        SDT_RISCVVTRUNCATE_VL>;
 
 def SDT_RISCVVWIntBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisInt<0>,
                                                   SDTCisInt<1>,
@@ -650,34 +652,6 @@ class VPatBinaryVL_V<SDPatternOperator vop,
                    op2_reg_class:$rs2,
                    (mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
 
-multiclass VPatBinaryRM_VL_V<SDNode vop,
-                             string instruction_name,
-                             string suffix,
-                             ValueType result_type,
-                             ValueType op1_type,
-                             ValueType op2_type,
-                             ValueType mask_type,
-                             int sew,
-                             LMULInfo vlmul,
-                             VReg result_reg_class,
-                             VReg op1_reg_class,
-                             VReg op2_reg_class> {
-  def : Pat<(result_type (vop
-                         (op1_type op1_reg_class:$rs1),
-                         (op2_type op2_reg_class:$rs2),
-                         (result_type result_reg_class:$merge),
-                         (mask_type V0),
-                         (XLenVT timm:$roundmode),
-                         VLOpFrag)),
-        (!cast<Instruction>(instruction_name#"_"#suffix#"_"# vlmul.MX#"_MASK")
-                     result_reg_class:$merge,
-                     op1_reg_class:$rs1,
-                     op2_reg_class:$rs2,
-                     (mask_type V0),
-                     (XLenVT timm:$roundmode),
-                     GPR:$vl, sew, TAIL_AGNOSTIC)>;
-}
-
 class VPatBinaryVL_V_RM<SDPatternOperator vop,
                         string instruction_name,
                         string suffix,
@@ -838,35 +812,6 @@ class VPatBinaryVL_XI<SDPatternOperator vop,
                    xop_kind:$rs2,
                    (mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>;
 
-multiclass VPatBinaryRM_VL_XI<SDNode vop,
-                              string instruction_name,
-                              string suffix,
-                              ValueType result_type,
-                              ValueType vop1_type,
-                              ValueType vop2_type,
-                              ValueType mask_type,
-                              int sew,
-                              LMULInfo vlmul,
-                              VReg result_reg_class,
-                              VReg vop_reg_class,
-                              ComplexPattern SplatPatKind,
-                              DAGOperand xop_kind> {
-  def : Pat<(result_type (vop
-                     (vop1_type vop_reg_class:$rs1),
-                     (vop2_type (SplatPatKind (XLenVT xop_kind:$rs2))),
-                     (result_type result_reg_class:$merge),
-                     (mask_type V0),
-                     (XLenVT timm:$roundmode),
-                     VLOpFrag)),
-        (!cast<Instruction>(instruction_name#_#suffix#_# vlmul.MX#"_MASK")
-                     result_reg_class:$merge,
-                     vop_reg_class:$rs1,
-                     xop_kind:$rs2,
-                     (mask_type V0),
-                     (XLenVT timm:$roundmode),
-                     GPR:$vl, sew, TAIL_AGNOSTIC)>;
-}
-
 multiclass VPatBinaryVL_VV_VX<SDPatternOperator vop, string instruction_name,
                               list<VTypeInfo> vtilist = AllIntegerVectors,
                               bit isSEWAware = 0> {
@@ -965,24 +910,6 @@ multiclass VPatBinaryNVL_WV_WX_WI<SDPatternOperator vop, string instruction_name
   }
 }
 
-multiclass VPatBinaryRM_NVL_WV_WX_WI<SDNode vop, string instruction_name> {
-  foreach VtiToWti = AllWidenableIntVectors in {
-    defvar vti = VtiToWti.Vti;
-    defvar wti = VtiToWti.Wti;
-    defm : VPatBinaryRM_VL_V<vop, instruction_name, "WV",
-                             vti.Vector, wti.Vector, vti.Vector, vti.Mask,
-                             vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, vti.RegClass>;
-    defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WX",
-                              vti.Vector, wti.Vector, vti.Vector, vti.Mask,
-                              vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass, SplatPat, GPR>;
-    defm : VPatBinaryRM_VL_XI<vop, instruction_name, "WI",
-                              vti.Vector, wti.Vector, vti.Vector, vti.Mask,
-                              vti.Log2SEW, vti.LMul, vti.RegClass, wti.RegClass,
-                              !cast<ComplexPattern>(SplatPat#_#uimm5),
-                              uimm5>;
-  }
-}
-
 class VPatBinaryVL_VF<SDPatternOperator vop,
                       string instruction_name,
                       ValueType result_type,
@@ -2468,8 +2395,26 @@ defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceils_vl, 0b00>;
 defm : VPatAVGADDVL_VV_VX_RM<riscv_avgceilu_vl, 0b00, suffix="U">;
 
 // 12.5. Vector Narrowing Fixed-Point Clip Instructions
-defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclip_vl, "PseudoVNCLIP">;
-defm : VPatBinaryRM_NVL_WV_WX_WI<riscv_vnclipu_vl, "PseudoVNCLIPU">;
+foreach vtiTowti = AllWidenableIntVectors in {
+  defvar vti = vtiTowti.Vti;
+  defvar wti = vtiTowti.Wti;
+  let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
+                               GetVTypePredicates<wti>.Predicates) in {
+    // Rounding mode here is arbitrary since we aren't shifting out any bits.
+    def : Pat<(vti.Vector (riscv_trunc_vector_vl_ssat (wti.Vector wti.RegClass:$rs1),
+                                                      (vti.Mask V0),
+                                                      VLOpFrag)),
+              (!cast<Instruction>("PseudoVNCLIP_WI_"#vti.LMul.MX#"_MASK")
+                  (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
+                  (vti.Mask V0), /*RNU*/0, GPR:$vl, vti.Log2SEW, TA_MA)>;
+    def : Pat<(vti.Vector (riscv_trunc_vector_vl_usat (wti.Vector wti.RegClass:$rs1),
+                                                      (vti.Mask V0),
+                                                      VLOpFrag)),
+              (!cast<Instruction>("PseudoVNCLIPU_WI_"#vti.LMul.MX#"_MASK")
+                  (vti.Vector (IMPLICIT_DEF)), wti.RegClass:$rs1, 0,
+                  (vti.Mask V0), /*RNU*/0, GPR:$vl, vti.Log2SEW, TA_MA)>;
+  }
+}
 
 // 13. Vector Floating-Point Instructions
 

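The `do { ... } while (ValVT != VT)` loop in `combineTruncToVnclip` emits one saturating truncate per halving of the element width. A scalar sketch of why chaining the steps gives the same answer as clamping once to the final type is shown below. It assumes the signed case and an i32 to i8 narrowing; all function names are hypothetical and exist only for this illustration.

```cpp
#include <cassert>
#include <cstdint>

// One signed saturating halving step, i32 -> i16 (illustrative model only).
constexpr int16_t ssat32to16(int32_t x) {
  return static_cast<int16_t>(x > INT16_MAX ? INT16_MAX
                              : x < INT16_MIN ? INT16_MIN
                                              : x);
}

// One signed saturating halving step, i16 -> i8 (illustrative model only).
constexpr int8_t ssat16to8(int16_t x) {
  return static_cast<int8_t>(x > INT8_MAX ? INT8_MAX
                             : x < INT8_MIN ? INT8_MIN
                                            : x);
}

// Two chained halving steps, mirroring one loop iteration per step.
constexpr int8_t ssat32to8Stepwise(int32_t x) {
  return ssat16to8(ssat32to16(x));
}

// Clamp straight to the final type in a single step.
constexpr int8_t ssat32to8Direct(int32_t x) {
  return static_cast<int8_t>(x > INT8_MAX ? INT8_MAX
                             : x < INT8_MIN ? INT8_MIN
                                            : x);
}

// Compare both approaches on a few values around the clamp boundaries.
inline void checkStepwiseMatchesDirect() {
  const int32_t samples[] = {-100000, -32769, -129, -128, -1, 0,
                             42,      127,    128,  32768, 100000};
  for (int32_t x : samples)
    assert(ssat32to8Stepwise(x) == ssat32to8Direct(x));
}
```

The equivalence holds because the i8 range is contained in the i16 range, so the first clamp never changes which side of the final bounds a value falls on. The same argument applies to the unsigned variant, where each step is a minimum against the narrower type's maximum.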
@dtcxzyw dtcxzyw changed the title [RISCV] Replace VNCLIP RISCVISD opcodes with TRUNCATE_VECTOR_VL_SSAT/… [RISCV] Replace VNCLIP RISCVISD opcodes with TRUNCATE_VECTOR_VL_SSAT/USAT opcodes Jul 23, 2024
Collaborator

@preames preames left a comment

LGTM

@topperc topperc merged commit caaba2a into llvm:main Jul 23, 2024
7 of 9 checks passed
@topperc topperc deleted the pr/vnclip branch July 23, 2024 21:57
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
…USAT opcodes (#100173)

Differential Revision: https://phabricator.intern.facebook.com/D60251265