Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV] Support (truncate (smin (smax X, C1), C2)) for vnclipu in combineTruncToVnclip. #93756

Merged
merged 3 commits into from
May 30, 2024

Conversation

topperc
Copy link
Collaborator

@topperc topperc commented May 30, 2024

If the smax removed all negative numbers, then we can treat the smin like a umin.

If the smin and smax are in the other order we can swap them and use a vnclipu as long as the smax constant is smaller than the smin constant.

This is based on similar code from X86's detectUSatPattern.

@llvmbot
Copy link
Collaborator

llvmbot commented May 30, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

Changes

If the smax removed all negative numbers, then we can treat the smin like a umin.

If the smin and smax are in the other order we can swap them and use a vnclipu as long as the smax constant is smaller than the smin constant.

This is based on similar code from X86's detectUSatPattern.


Full diff: https://github.com/llvm/llvm-project/pull/93756.diff

3 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+33-16)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-trunc-sat-clip.ll (+6-22)
  • (modified) llvm/test/CodeGen/RISCV/rvv/trunc-sat-clip-sdnode.ll (+9-25)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 0e7713509e969..c079bf5a16b94 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16230,17 +16230,35 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
     return SDValue();
   };
 
+  SDLoc DL(N);
+
   auto DetectUSatPattern = [&](SDValue V) {
-    // Src must be a UMIN or UMIN_VL.
-    APInt C;
-    SDValue UMin = MatchMinMax(V, ISD::UMIN, RISCVISD::UMIN_VL, C);
-    if (!UMin)
-      return SDValue();
+    APInt LoC, HiC;
+
+    // Simple case, V is a UMIN.
+    if (SDValue UMin = MatchMinMax(V, ISD::UMIN, RISCVISD::UMIN_VL, HiC))
+      if (HiC.isMask(VT.getScalarSizeInBits()))
+        return UMin;
+
+    // If we have an SMAX that removes negative numbers first, then we can match
+    // SMIN instead of UMIN.
+    if (SDValue SMin = MatchMinMax(V, ISD::SMIN, RISCVISD::SMIN_VL, HiC))
+      if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, RISCVISD::SMAX_VL, LoC))
+        if (LoC.isNonNegative() && HiC.isMask(VT.getScalarSizeInBits()))
+          return SMin;
 
-    if (!C.isMask(VT.getScalarSizeInBits()))
-      return SDValue();
+    // If we have an SMIN before an SMAX and the SMAX constant is less than or
+    // equal to the SMIN constant, we can use vnclipu if we insert a new SMAX
+    // first.
+    if (SDValue SMax = MatchMinMax(V, ISD::SMAX, RISCVISD::SMAX_VL, LoC))
+      if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, RISCVISD::SMIN_VL, HiC))
+        if (LoC.isNonNegative() && HiC.isMask(VT.getScalarSizeInBits()) &&
+            HiC.uge(LoC))
+          return DAG.getNode(RISCVISD::SMAX_VL, DL, V.getValueType(),
+                             SMin, V.getOperand(1), DAG.getUNDEF(V.getValueType()),
+                             Mask, VL);
 
-    return UMin;
+    return SDValue();
   };
 
   auto DetectSSatPattern = [&](SDValue V) {
@@ -16249,15 +16267,15 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
     APInt SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
     APInt SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
 
-    APInt CMin, CMax;
-    if (SDValue SMin = MatchMinMax(V, ISD::SMIN, RISCVISD::SMIN_VL, CMin))
-      if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, RISCVISD::SMAX_VL, CMax))
-        if (CMin == SignedMax && CMax == SignedMin)
+    APInt HiC, LoC;
+    if (SDValue SMin = MatchMinMax(V, ISD::SMIN, RISCVISD::SMIN_VL, HiC))
+      if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, RISCVISD::SMAX_VL, LoC))
+        if (HiC == SignedMax && LoC == SignedMin)
           return SMax;
 
-    if (SDValue SMax = MatchMinMax(V, ISD::SMAX, RISCVISD::SMAX_VL, CMax))
-      if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, RISCVISD::SMIN_VL, CMin))
-        if (CMin == SignedMax && CMax == SignedMin)
+    if (SDValue SMax = MatchMinMax(V, ISD::SMAX, RISCVISD::SMAX_VL, LoC))
+      if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, RISCVISD::SMIN_VL, HiC))
+        if (HiC == SignedMax && LoC == SignedMin)
           return SMin;
 
     return SDValue();
@@ -16272,7 +16290,6 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
   else
     return SDValue();
 
-  SDLoc DL(N);
   // Rounding mode here is arbitrary since we aren't shifting out any bits.
   return DAG.getNode(
       ClipOpc, DL, VT,
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-trunc-sat-clip.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-trunc-sat-clip.ll
index 9f82eddf432da..aaab2e5b3d97a 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-trunc-sat-clip.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-trunc-sat-clip.ll
@@ -105,10 +105,8 @@ define void @trunc_sat_u8u16_maxmin(ptr %x, ptr %y) {
 ; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
 ; CHECK-NEXT:    vle16.v v8, (a0)
 ; CHECK-NEXT:    vmax.vx v8, v8, zero
-; CHECK-NEXT:    li a0, 255
-; CHECK-NEXT:    vmin.vx v8, v8, a0
 ; CHECK-NEXT:    vsetvli zero, zero, e8, mf4, ta, ma
-; CHECK-NEXT:    vnsrl.wi v8, v8, 0
+; CHECK-NEXT:    vnclipu.wi v8, v8, 0
 ; CHECK-NEXT:    vse8.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = load <4 x i16>, ptr %x, align 16
@@ -125,11 +123,9 @@ define void @trunc_sat_u8u16_minmax(ptr %x, ptr %y) {
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetivli zero, 4, e16, mf2, ta, ma
 ; CHECK-NEXT:    vle16.v v8, (a0)
-; CHECK-NEXT:    li a0, 255
-; CHECK-NEXT:    vmin.vx v8, v8, a0
 ; CHECK-NEXT:    vmax.vx v8, v8, zero
 ; CHECK-NEXT:    vsetvli zero, zero, e8, mf4, ta, ma
-; CHECK-NEXT:    vnsrl.wi v8, v8, 0
+; CHECK-NEXT:    vnclipu.wi v8, v8, 0
 ; CHECK-NEXT:    vse8.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = load <4 x i16>, ptr %x, align 16
@@ -237,11 +233,8 @@ define void @trunc_sat_u16u32_maxmin(ptr %x, ptr %y) {
 ; CHECK-NEXT:    vle32.v v8, (a0)
 ; CHECK-NEXT:    li a0, 1
 ; CHECK-NEXT:    vmax.vx v8, v8, a0
-; CHECK-NEXT:    lui a0, 16
-; CHECK-NEXT:    addi a0, a0, -1
-; CHECK-NEXT:    vmin.vx v8, v8, a0
 ; CHECK-NEXT:    vsetvli zero, zero, e16, mf2, ta, ma
-; CHECK-NEXT:    vnsrl.wi v8, v8, 0
+; CHECK-NEXT:    vnclipu.wi v8, v8, 0
 ; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = load <4 x i32>, ptr %x, align 16
@@ -258,13 +251,10 @@ define void @trunc_sat_u16u32_minmax(ptr %x, ptr %y) {
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
 ; CHECK-NEXT:    vle32.v v8, (a0)
-; CHECK-NEXT:    lui a0, 16
-; CHECK-NEXT:    addi a0, a0, -1
-; CHECK-NEXT:    vmin.vx v8, v8, a0
 ; CHECK-NEXT:    li a0, 50
 ; CHECK-NEXT:    vmax.vx v8, v8, a0
 ; CHECK-NEXT:    vsetvli zero, zero, e16, mf2, ta, ma
-; CHECK-NEXT:    vnsrl.wi v8, v8, 0
+; CHECK-NEXT:    vnclipu.wi v8, v8, 0
 ; CHECK-NEXT:    vse16.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = load <4 x i32>, ptr %x, align 16
@@ -374,11 +364,8 @@ define void @trunc_sat_u32u64_maxmin(ptr %x, ptr %y) {
 ; CHECK-NEXT:    vsetivli zero, 4, e64, m2, ta, ma
 ; CHECK-NEXT:    vle64.v v8, (a0)
 ; CHECK-NEXT:    vmax.vx v8, v8, zero
-; CHECK-NEXT:    li a0, -1
-; CHECK-NEXT:    srli a0, a0, 32
-; CHECK-NEXT:    vmin.vx v8, v8, a0
 ; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vnsrl.wi v10, v8, 0
+; CHECK-NEXT:    vnclipu.wi v10, v8, 0
 ; CHECK-NEXT:    vse32.v v10, (a1)
 ; CHECK-NEXT:    ret
   %1 = load <4 x i64>, ptr %x, align 16
@@ -395,12 +382,9 @@ define void @trunc_sat_u32u64_minmax(ptr %x, ptr %y) {
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vsetivli zero, 4, e64, m2, ta, ma
 ; CHECK-NEXT:    vle64.v v8, (a0)
-; CHECK-NEXT:    li a0, -1
-; CHECK-NEXT:    srli a0, a0, 32
-; CHECK-NEXT:    vmin.vx v8, v8, a0
 ; CHECK-NEXT:    vmax.vx v8, v8, zero
 ; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT:    vnsrl.wi v10, v8, 0
+; CHECK-NEXT:    vnclipu.wi v10, v8, 0
 ; CHECK-NEXT:    vse32.v v10, (a1)
 ; CHECK-NEXT:    ret
   %1 = load <4 x i64>, ptr %x, align 16
diff --git a/llvm/test/CodeGen/RISCV/rvv/trunc-sat-clip-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/trunc-sat-clip-sdnode.ll
index 78e8f0fbbbdd7..4ca99be1edb99 100644
--- a/llvm/test/CodeGen/RISCV/rvv/trunc-sat-clip-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/trunc-sat-clip-sdnode.ll
@@ -105,10 +105,8 @@ define void @trunc_sat_u8u16_maxmin(ptr %x, ptr %y) {
 ; CHECK-NEXT:    vl1re16.v v8, (a0)
 ; CHECK-NEXT:    vsetvli a0, zero, e16, m1, ta, ma
 ; CHECK-NEXT:    vmax.vx v8, v8, zero
-; CHECK-NEXT:    li a0, 255
-; CHECK-NEXT:    vmin.vx v8, v8, a0
 ; CHECK-NEXT:    vsetvli zero, zero, e8, mf2, ta, ma
-; CHECK-NEXT:    vnsrl.wi v8, v8, 0
+; CHECK-NEXT:    vnclipu.wi v8, v8, 0
 ; CHECK-NEXT:    vse8.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = load <vscale x 4 x i16>, ptr %x, align 16
@@ -124,12 +122,10 @@ define void @trunc_sat_u8u16_minmax(ptr %x, ptr %y) {
 ; CHECK-LABEL: trunc_sat_u8u16_minmax:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vl1re16.v v8, (a0)
-; CHECK-NEXT:    li a0, 255
-; CHECK-NEXT:    vsetvli a2, zero, e16, m1, ta, ma
-; CHECK-NEXT:    vmin.vx v8, v8, a0
+; CHECK-NEXT:    vsetvli a0, zero, e16, m1, ta, ma
 ; CHECK-NEXT:    vmax.vx v8, v8, zero
 ; CHECK-NEXT:    vsetvli zero, zero, e8, mf2, ta, ma
-; CHECK-NEXT:    vnsrl.wi v8, v8, 0
+; CHECK-NEXT:    vnclipu.wi v8, v8, 0
 ; CHECK-NEXT:    vse8.v v8, (a1)
 ; CHECK-NEXT:    ret
   %1 = load <vscale x 4 x i16>, ptr %x, align 16
@@ -237,11 +233,8 @@ define void @trunc_sat_u16u32_maxmin(ptr %x, ptr %y) {
 ; CHECK-NEXT:    li a0, 1
 ; CHECK-NEXT:    vsetvli a2, zero, e32, m2, ta, ma
 ; CHECK-NEXT:    vmax.vx v8, v8, a0
-; CHECK-NEXT:    lui a0, 16
-; CHECK-NEXT:    addi a0, a0, -1
-; CHECK-NEXT:    vmin.vx v8, v8, a0
 ; CHECK-NEXT:    vsetvli zero, zero, e16, m1, ta, ma
-; CHECK-NEXT:    vnsrl.wi v10, v8, 0
+; CHECK-NEXT:    vnclipu.wi v10, v8, 0
 ; CHECK-NEXT:    vs1r.v v10, (a1)
 ; CHECK-NEXT:    ret
   %1 = load <vscale x 4 x i32>, ptr %x, align 16
@@ -257,14 +250,11 @@ define void @trunc_sat_u16u32_minmax(ptr %x, ptr %y) {
 ; CHECK-LABEL: trunc_sat_u16u32_minmax:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vl2re32.v v8, (a0)
-; CHECK-NEXT:    lui a0, 16
-; CHECK-NEXT:    addi a0, a0, -1
-; CHECK-NEXT:    vsetvli a2, zero, e32, m2, ta, ma
-; CHECK-NEXT:    vmin.vx v8, v8, a0
 ; CHECK-NEXT:    li a0, 50
+; CHECK-NEXT:    vsetvli a2, zero, e32, m2, ta, ma
 ; CHECK-NEXT:    vmax.vx v8, v8, a0
 ; CHECK-NEXT:    vsetvli zero, zero, e16, m1, ta, ma
-; CHECK-NEXT:    vnsrl.wi v10, v8, 0
+; CHECK-NEXT:    vnclipu.wi v10, v8, 0
 ; CHECK-NEXT:    vs1r.v v10, (a1)
 ; CHECK-NEXT:    ret
   %1 = load <vscale x 4 x i32>, ptr %x, align 16
@@ -374,11 +364,8 @@ define void @trunc_sat_u32u64_maxmin(ptr %x, ptr %y) {
 ; CHECK-NEXT:    vl4re64.v v8, (a0)
 ; CHECK-NEXT:    vsetvli a0, zero, e64, m4, ta, ma
 ; CHECK-NEXT:    vmax.vx v8, v8, zero
-; CHECK-NEXT:    li a0, -1
-; CHECK-NEXT:    srli a0, a0, 32
-; CHECK-NEXT:    vmin.vx v8, v8, a0
 ; CHECK-NEXT:    vsetvli zero, zero, e32, m2, ta, ma
-; CHECK-NEXT:    vnsrl.wi v12, v8, 0
+; CHECK-NEXT:    vnclipu.wi v12, v8, 0
 ; CHECK-NEXT:    vs2r.v v12, (a1)
 ; CHECK-NEXT:    ret
   %1 = load <vscale x 4 x i64>, ptr %x, align 16
@@ -394,13 +381,10 @@ define void @trunc_sat_u32u64_minmax(ptr %x, ptr %y) {
 ; CHECK-LABEL: trunc_sat_u32u64_minmax:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    vl4re64.v v8, (a0)
-; CHECK-NEXT:    li a0, -1
-; CHECK-NEXT:    srli a0, a0, 32
-; CHECK-NEXT:    vsetvli a2, zero, e64, m4, ta, ma
-; CHECK-NEXT:    vmin.vx v8, v8, a0
+; CHECK-NEXT:    vsetvli a0, zero, e64, m4, ta, ma
 ; CHECK-NEXT:    vmax.vx v8, v8, zero
 ; CHECK-NEXT:    vsetvli zero, zero, e32, m2, ta, ma
-; CHECK-NEXT:    vnsrl.wi v12, v8, 0
+; CHECK-NEXT:    vnclipu.wi v12, v8, 0
 ; CHECK-NEXT:    vs2r.v v12, (a1)
 ; CHECK-NEXT:    ret
   %1 = load <vscale x 4 x i64>, ptr %x, align 16

Copy link

github-actions bot commented May 30, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

…bineTruncToVnclip.

If the smax removed all negative numbers, then we can treat the
smin like a umin.

If the smin and smax are in the other order we can swap them and
use a vnclipu as long as the smax constant is smaller than the
smin constant.

This is based on similar code from X86's detectUSatPattern.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the FIXMEs above these functions can be removed now

if (SDValue SMin = MatchMinMax(V, ISD::SMIN, RISCVISD::SMIN_VL, HiC))
if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, RISCVISD::SMAX_VL, LoC))
if (LoC.isNonNegative() && HiC.isMask(VT.getScalarSizeInBits()))
return SMin;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to double check, SMin is the ISD::SMAX and SMax is the input? This is a nit but would it be clearer to rename it something more like SMinOp and SMaxOp

Copy link
Collaborator

@preames preames left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM w/minor comment. I agree with Luke's point on naming.

// If we have an SMAX that removes negative numbers first, then we can match
// SMIN instead of UMIN.
if (SDValue SMin = MatchMinMax(V, ISD::SMIN, RISCVISD::SMIN_VL, HiC))
if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, RISCVISD::SMAX_VL, LoC))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we have a known positive input the SMIN, why is that not being canonicalized to a umin?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good question. I also ran these tests through InstCombine and it didn't convert the smin to umin either.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@preames See #82478 and #88505.

@topperc We did this canonicalization in CVP.

@topperc topperc merged commit 8247068 into llvm:main May 30, 2024
5 of 6 checks passed
@topperc topperc deleted the pr/smin-smax-vnclipu branch May 30, 2024 20:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants