-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RISCV] Support (truncate (smin (smax X, C1), C2)) for vnclipu in combineTruncToVnclip. #93756
Conversation
@llvm/pr-subscribers-backend-risc-v Author: Craig Topper (topperc) Changes: If the smax removed all negative numbers, then we can treat the smin like a umin. If the smin and smax are in the other order we can swap them and use a vnclipu as long as the smax constant is smaller than the smin constant. This is based on similar code from X86's detectUSatPattern. Full diff: https://github.com/llvm/llvm-project/pull/93756.diff 3 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 0e7713509e969..c079bf5a16b94 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16230,17 +16230,35 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
return SDValue();
};
+ SDLoc DL(N);
+
auto DetectUSatPattern = [&](SDValue V) {
- // Src must be a UMIN or UMIN_VL.
- APInt C;
- SDValue UMin = MatchMinMax(V, ISD::UMIN, RISCVISD::UMIN_VL, C);
- if (!UMin)
- return SDValue();
+ APInt LoC, HiC;
+
+ // Simple case, V is a UMIN.
+ if (SDValue UMin = MatchMinMax(V, ISD::UMIN, RISCVISD::UMIN_VL, HiC))
+ if (HiC.isMask(VT.getScalarSizeInBits()))
+ return UMin;
+
+ // If we have an SMAX that removes negative numbers first, then we can match
+ // SMIN instead of UMIN.
+ if (SDValue SMin = MatchMinMax(V, ISD::SMIN, RISCVISD::SMIN_VL, HiC))
+ if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, RISCVISD::SMAX_VL, LoC))
+ if (LoC.isNonNegative() && HiC.isMask(VT.getScalarSizeInBits()))
+ return SMin;
- if (!C.isMask(VT.getScalarSizeInBits()))
- return SDValue();
+ // If we have an SMIN before an SMAX and the SMAX constant is less than or
+ // equal to the SMIN constant, we can use vnclipu if we insert a new SMAX
+ // first.
+ if (SDValue SMax = MatchMinMax(V, ISD::SMAX, RISCVISD::SMAX_VL, LoC))
+ if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, RISCVISD::SMIN_VL, HiC))
+ if (LoC.isNonNegative() && HiC.isMask(VT.getScalarSizeInBits()) &&
+ HiC.uge(LoC))
+ return DAG.getNode(RISCVISD::SMAX_VL, DL, V.getValueType(),
+ SMin, V.getOperand(1), DAG.getUNDEF(V.getValueType()),
+ Mask, VL);
- return UMin;
+ return SDValue();
};
auto DetectSSatPattern = [&](SDValue V) {
@@ -16249,15 +16267,15 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
APInt SignedMax = APInt::getSignedMaxValue(NumDstBits).sext(NumSrcBits);
APInt SignedMin = APInt::getSignedMinValue(NumDstBits).sext(NumSrcBits);
- APInt CMin, CMax;
- if (SDValue SMin = MatchMinMax(V, ISD::SMIN, RISCVISD::SMIN_VL, CMin))
- if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, RISCVISD::SMAX_VL, CMax))
- if (CMin == SignedMax && CMax == SignedMin)
+ APInt HiC, LoC;
+ if (SDValue SMin = MatchMinMax(V, ISD::SMIN, RISCVISD::SMIN_VL, HiC))
+ if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, RISCVISD::SMAX_VL, LoC))
+ if (HiC == SignedMax && LoC == SignedMin)
return SMax;
- if (SDValue SMax = MatchMinMax(V, ISD::SMAX, RISCVISD::SMAX_VL, CMax))
- if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, RISCVISD::SMIN_VL, CMin))
- if (CMin == SignedMax && CMax == SignedMin)
+ if (SDValue SMax = MatchMinMax(V, ISD::SMAX, RISCVISD::SMAX_VL, LoC))
+ if (SDValue SMin = MatchMinMax(SMax, ISD::SMIN, RISCVISD::SMIN_VL, HiC))
+ if (HiC == SignedMax && LoC == SignedMin)
return SMin;
return SDValue();
@@ -16272,7 +16290,6 @@ static SDValue combineTruncToVnclip(SDNode *N, SelectionDAG &DAG,
else
return SDValue();
- SDLoc DL(N);
// Rounding mode here is arbitrary since we aren't shifting out any bits.
return DAG.getNode(
ClipOpc, DL, VT,
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-trunc-sat-clip.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-trunc-sat-clip.ll
index 9f82eddf432da..aaab2e5b3d97a 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-trunc-sat-clip.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-trunc-sat-clip.ll
@@ -105,10 +105,8 @@ define void @trunc_sat_u8u16_maxmin(ptr %x, ptr %y) {
; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
; CHECK-NEXT: vle16.v v8, (a0)
; CHECK-NEXT: vmax.vx v8, v8, zero
-; CHECK-NEXT: li a0, 255
-; CHECK-NEXT: vmin.vx v8, v8, a0
; CHECK-NEXT: vsetvli zero, zero, e8, mf4, ta, ma
-; CHECK-NEXT: vnsrl.wi v8, v8, 0
+; CHECK-NEXT: vnclipu.wi v8, v8, 0
; CHECK-NEXT: vse8.v v8, (a1)
; CHECK-NEXT: ret
%1 = load <4 x i16>, ptr %x, align 16
@@ -125,11 +123,9 @@ define void @trunc_sat_u8u16_minmax(ptr %x, ptr %y) {
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
; CHECK-NEXT: vle16.v v8, (a0)
-; CHECK-NEXT: li a0, 255
-; CHECK-NEXT: vmin.vx v8, v8, a0
; CHECK-NEXT: vmax.vx v8, v8, zero
; CHECK-NEXT: vsetvli zero, zero, e8, mf4, ta, ma
-; CHECK-NEXT: vnsrl.wi v8, v8, 0
+; CHECK-NEXT: vnclipu.wi v8, v8, 0
; CHECK-NEXT: vse8.v v8, (a1)
; CHECK-NEXT: ret
%1 = load <4 x i16>, ptr %x, align 16
@@ -237,11 +233,8 @@ define void @trunc_sat_u16u32_maxmin(ptr %x, ptr %y) {
; CHECK-NEXT: vle32.v v8, (a0)
; CHECK-NEXT: li a0, 1
; CHECK-NEXT: vmax.vx v8, v8, a0
-; CHECK-NEXT: lui a0, 16
-; CHECK-NEXT: addi a0, a0, -1
-; CHECK-NEXT: vmin.vx v8, v8, a0
; CHECK-NEXT: vsetvli zero, zero, e16, mf2, ta, ma
-; CHECK-NEXT: vnsrl.wi v8, v8, 0
+; CHECK-NEXT: vnclipu.wi v8, v8, 0
; CHECK-NEXT: vse16.v v8, (a1)
; CHECK-NEXT: ret
%1 = load <4 x i32>, ptr %x, align 16
@@ -258,13 +251,10 @@ define void @trunc_sat_u16u32_minmax(ptr %x, ptr %y) {
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 4, e32, m1, ta, ma
; CHECK-NEXT: vle32.v v8, (a0)
-; CHECK-NEXT: lui a0, 16
-; CHECK-NEXT: addi a0, a0, -1
-; CHECK-NEXT: vmin.vx v8, v8, a0
; CHECK-NEXT: li a0, 50
; CHECK-NEXT: vmax.vx v8, v8, a0
; CHECK-NEXT: vsetvli zero, zero, e16, mf2, ta, ma
-; CHECK-NEXT: vnsrl.wi v8, v8, 0
+; CHECK-NEXT: vnclipu.wi v8, v8, 0
; CHECK-NEXT: vse16.v v8, (a1)
; CHECK-NEXT: ret
%1 = load <4 x i32>, ptr %x, align 16
@@ -374,11 +364,8 @@ define void @trunc_sat_u32u64_maxmin(ptr %x, ptr %y) {
; CHECK-NEXT: vsetivli zero, 4, e64, m2, ta, ma
; CHECK-NEXT: vle64.v v8, (a0)
; CHECK-NEXT: vmax.vx v8, v8, zero
-; CHECK-NEXT: li a0, -1
-; CHECK-NEXT: srli a0, a0, 32
-; CHECK-NEXT: vmin.vx v8, v8, a0
; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT: vnsrl.wi v10, v8, 0
+; CHECK-NEXT: vnclipu.wi v10, v8, 0
; CHECK-NEXT: vse32.v v10, (a1)
; CHECK-NEXT: ret
%1 = load <4 x i64>, ptr %x, align 16
@@ -395,12 +382,9 @@ define void @trunc_sat_u32u64_minmax(ptr %x, ptr %y) {
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 4, e64, m2, ta, ma
; CHECK-NEXT: vle64.v v8, (a0)
-; CHECK-NEXT: li a0, -1
-; CHECK-NEXT: srli a0, a0, 32
-; CHECK-NEXT: vmin.vx v8, v8, a0
; CHECK-NEXT: vmax.vx v8, v8, zero
; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
-; CHECK-NEXT: vnsrl.wi v10, v8, 0
+; CHECK-NEXT: vnclipu.wi v10, v8, 0
; CHECK-NEXT: vse32.v v10, (a1)
; CHECK-NEXT: ret
%1 = load <4 x i64>, ptr %x, align 16
diff --git a/llvm/test/CodeGen/RISCV/rvv/trunc-sat-clip-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/trunc-sat-clip-sdnode.ll
index 78e8f0fbbbdd7..4ca99be1edb99 100644
--- a/llvm/test/CodeGen/RISCV/rvv/trunc-sat-clip-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/trunc-sat-clip-sdnode.ll
@@ -105,10 +105,8 @@ define void @trunc_sat_u8u16_maxmin(ptr %x, ptr %y) {
; CHECK-NEXT: vl1re16.v v8, (a0)
; CHECK-NEXT: vsetvli a0, zero, e16, m1, ta, ma
; CHECK-NEXT: vmax.vx v8, v8, zero
-; CHECK-NEXT: li a0, 255
-; CHECK-NEXT: vmin.vx v8, v8, a0
; CHECK-NEXT: vsetvli zero, zero, e8, mf2, ta, ma
-; CHECK-NEXT: vnsrl.wi v8, v8, 0
+; CHECK-NEXT: vnclipu.wi v8, v8, 0
; CHECK-NEXT: vse8.v v8, (a1)
; CHECK-NEXT: ret
%1 = load <vscale x 4 x i16>, ptr %x, align 16
@@ -124,12 +122,10 @@ define void @trunc_sat_u8u16_minmax(ptr %x, ptr %y) {
; CHECK-LABEL: trunc_sat_u8u16_minmax:
; CHECK: # %bb.0:
; CHECK-NEXT: vl1re16.v v8, (a0)
-; CHECK-NEXT: li a0, 255
-; CHECK-NEXT: vsetvli a2, zero, e16, m1, ta, ma
-; CHECK-NEXT: vmin.vx v8, v8, a0
+; CHECK-NEXT: vsetvli a0, zero, e16, m1, ta, ma
; CHECK-NEXT: vmax.vx v8, v8, zero
; CHECK-NEXT: vsetvli zero, zero, e8, mf2, ta, ma
-; CHECK-NEXT: vnsrl.wi v8, v8, 0
+; CHECK-NEXT: vnclipu.wi v8, v8, 0
; CHECK-NEXT: vse8.v v8, (a1)
; CHECK-NEXT: ret
%1 = load <vscale x 4 x i16>, ptr %x, align 16
@@ -237,11 +233,8 @@ define void @trunc_sat_u16u32_maxmin(ptr %x, ptr %y) {
; CHECK-NEXT: li a0, 1
; CHECK-NEXT: vsetvli a2, zero, e32, m2, ta, ma
; CHECK-NEXT: vmax.vx v8, v8, a0
-; CHECK-NEXT: lui a0, 16
-; CHECK-NEXT: addi a0, a0, -1
-; CHECK-NEXT: vmin.vx v8, v8, a0
; CHECK-NEXT: vsetvli zero, zero, e16, m1, ta, ma
-; CHECK-NEXT: vnsrl.wi v10, v8, 0
+; CHECK-NEXT: vnclipu.wi v10, v8, 0
; CHECK-NEXT: vs1r.v v10, (a1)
; CHECK-NEXT: ret
%1 = load <vscale x 4 x i32>, ptr %x, align 16
@@ -257,14 +250,11 @@ define void @trunc_sat_u16u32_minmax(ptr %x, ptr %y) {
; CHECK-LABEL: trunc_sat_u16u32_minmax:
; CHECK: # %bb.0:
; CHECK-NEXT: vl2re32.v v8, (a0)
-; CHECK-NEXT: lui a0, 16
-; CHECK-NEXT: addi a0, a0, -1
-; CHECK-NEXT: vsetvli a2, zero, e32, m2, ta, ma
-; CHECK-NEXT: vmin.vx v8, v8, a0
; CHECK-NEXT: li a0, 50
+; CHECK-NEXT: vsetvli a2, zero, e32, m2, ta, ma
; CHECK-NEXT: vmax.vx v8, v8, a0
; CHECK-NEXT: vsetvli zero, zero, e16, m1, ta, ma
-; CHECK-NEXT: vnsrl.wi v10, v8, 0
+; CHECK-NEXT: vnclipu.wi v10, v8, 0
; CHECK-NEXT: vs1r.v v10, (a1)
; CHECK-NEXT: ret
%1 = load <vscale x 4 x i32>, ptr %x, align 16
@@ -374,11 +364,8 @@ define void @trunc_sat_u32u64_maxmin(ptr %x, ptr %y) {
; CHECK-NEXT: vl4re64.v v8, (a0)
; CHECK-NEXT: vsetvli a0, zero, e64, m4, ta, ma
; CHECK-NEXT: vmax.vx v8, v8, zero
-; CHECK-NEXT: li a0, -1
-; CHECK-NEXT: srli a0, a0, 32
-; CHECK-NEXT: vmin.vx v8, v8, a0
; CHECK-NEXT: vsetvli zero, zero, e32, m2, ta, ma
-; CHECK-NEXT: vnsrl.wi v12, v8, 0
+; CHECK-NEXT: vnclipu.wi v12, v8, 0
; CHECK-NEXT: vs2r.v v12, (a1)
; CHECK-NEXT: ret
%1 = load <vscale x 4 x i64>, ptr %x, align 16
@@ -394,13 +381,10 @@ define void @trunc_sat_u32u64_minmax(ptr %x, ptr %y) {
; CHECK-LABEL: trunc_sat_u32u64_minmax:
; CHECK: # %bb.0:
; CHECK-NEXT: vl4re64.v v8, (a0)
-; CHECK-NEXT: li a0, -1
-; CHECK-NEXT: srli a0, a0, 32
-; CHECK-NEXT: vsetvli a2, zero, e64, m4, ta, ma
-; CHECK-NEXT: vmin.vx v8, v8, a0
+; CHECK-NEXT: vsetvli a0, zero, e64, m4, ta, ma
; CHECK-NEXT: vmax.vx v8, v8, zero
; CHECK-NEXT: vsetvli zero, zero, e32, m2, ta, ma
-; CHECK-NEXT: vnsrl.wi v12, v8, 0
+; CHECK-NEXT: vnclipu.wi v12, v8, 0
; CHECK-NEXT: vs2r.v v12, (a1)
; CHECK-NEXT: ret
%1 = load <vscale x 4 x i64>, ptr %x, align 16
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
…bineTruncToVnclip. If the smax removed all negative numbers, then we can treat the smin like a umin. If the smin and smax are in the other order we can swap them and use a vnclipu as long as the smax constant is smaller than the smin constant. This is based on similar code from X86's detectUSatPattern.
3c91d72
to
951974e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the FIXMEs above these functions can be removed now
if (SDValue SMin = MatchMinMax(V, ISD::SMIN, RISCVISD::SMIN_VL, HiC)) | ||
if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, RISCVISD::SMAX_VL, LoC)) | ||
if (LoC.isNonNegative() && HiC.isMask(VT.getScalarSizeInBits())) | ||
return SMin; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to double check, SMin is the ISD::SMAX and SMax is the input? This is a nit, but would it be clearer to rename them to something more like SMinOp and SMaxOp?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM w/minor comment. I agree with Luke's point on naming.
// If we have an SMAX that removes negative numbers first, then we can match | ||
// SMIN instead of UMIN. | ||
if (SDValue SMin = MatchMinMax(V, ISD::SMIN, RISCVISD::SMIN_VL, HiC)) | ||
if (SDValue SMax = MatchMinMax(SMin, ISD::SMAX, RISCVISD::SMAX_VL, LoC)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we have a known positive input to the SMIN, why is that not being canonicalized to a umin?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a good question. I also ran these tests through InstCombine and it didn't convert the smin to umin either.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the smax removed all negative numbers, then we can treat the smin like a umin.
If the smin and smax are in the other order we can swap them and use a vnclipu as long as the smax constant is smaller than the smin constant.
This is based on similar code from X86's detectUSatPattern.