Skip to content

Optimize fptrunc(x)>=C1 --> x>=C2 #99475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Conversation

kissholic
Copy link

@kissholic kissholic requested a review from nikic as a code owner July 18, 2024 11:55
Copy link

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be
notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write
permissions for the repository. In which case you can instead tag reviewers by
name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review
by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate
is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Copy link
Member

llvmbot commented Jul 18, 2024

@llvm/pr-subscribers-llvm-transforms

Author: None (kissholic)

Changes

Fix #85265 (comment)


Full diff: https://github.com/llvm/llvm-project/pull/99475.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+31)
  • (added) llvm/test/Transforms/InstCombine/fold-fcmp-trunc.ll (+11)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index abadf54a96767..2af3e92213f13 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -22,10 +22,13 @@
 #include "llvm/Analysis/Utils/Local.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/ConstantRange.h"
+#include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/KnownBits.h"
 #include "llvm/Transforms/InstCombine/InstCombiner.h"
 #include <bitset>
@@ -7882,6 +7885,30 @@ static Instruction *foldFCmpReciprocalAndZero(FCmpInst &I, Instruction *LHSI,
   return new FCmpInst(Pred, LHSI->getOperand(1), RHSC, "", &I);
 }
 
+// Fold trunc(x) < constant --> x < constant if possible.
+static Instruction *foldFCmpFpTrunc(FCmpInst &I, Instruction *LHSI,
+                                    Constant *RHSC) {
+  //
+  FCmpInst::Predicate Pred = I.getPredicate();
+
+  // Check that predicates are valid.
+  if ((Pred != FCmpInst::FCMP_OGT) && (Pred != FCmpInst::FCMP_OLT) &&
+      (Pred != FCmpInst::FCMP_OGE) && (Pred != FCmpInst::FCMP_OLE))
+    return nullptr;
+
+  auto *LType = LHSI->getOperand(0)->getType();
+  auto *RType = RHSC->getType();
+
+  if (!(LType->isFloatingPointTy() && RType->isFloatingPointTy() &&
+        LType->getTypeID() >= RType->getTypeID()))
+    return nullptr;
+
+  auto *ROperand = llvm::ConstantFP::get(
+      LType, dyn_cast<ConstantFP>(RHSC)->getValue().convertToDouble());
+
+  return new FCmpInst(Pred, LHSI->getOperand(0), ROperand, "", &I);
+}
+
 /// Optimize fabs(X) compared with zero.
 static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) {
   Value *X;
@@ -8244,6 +8271,10 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
                   cast<LoadInst>(LHSI), GEP, GV, I))
             return Res;
       break;
+    case Instruction::FPTrunc:
+      if (Instruction *NV = foldFCmpFpTrunc(I, LHSI, RHSC))
+        return NV;
+      break;
   }
   }
 
diff --git a/llvm/test/Transforms/InstCombine/fold-fcmp-trunc.ll b/llvm/test/Transforms/InstCombine/fold-fcmp-trunc.ll
new file mode 100644
index 0000000000000..446111a60dd6c
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/fold-fcmp-trunc.ll
@@ -0,0 +1,11 @@
+; RUN: opt -passes=instcombine -S < %s | FileCheck %s
+
+
+;CHECK-LABEL: @src(
+;CHECK: %result = fcmp oge double %0, 1.000000e+02
+;CHECK-NEXT: ret i1 %result
+define i1 @src(double %0) {
+    %trunc = fptrunc double %0 to float
+    %result = fcmp oge float %trunc, 1.000000e+02
+    ret i1 %result
+}
\ No newline at end of file

@dtcxzyw dtcxzyw requested a review from arsenm July 18, 2024 11:58
@dtcxzyw
Copy link
Member

dtcxzyw commented Jul 18, 2024

Please read the guideline https://llvm.org/docs/InstCombineContributorGuide.html.

%trunc = fptrunc double %0 to float
%result = fcmp oge float %trunc, 1.000000e+02
ret i1 %result
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing newline.

auto *LType = LHSI->getOperand(0)->getType();
auto *RType = RHSC->getType();

if (!(LType->isFloatingPointTy() && RType->isFloatingPointTy() &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You shouldn't need to check isFloatingPointTy, that's implied by the operations in use already. Also not sure what a >= means when comparing type IDs but you probably don't need that either

return nullptr;

auto *ROperand = llvm::ConstantFP::get(
LType, dyn_cast<ConstantFP>(RHSC)->getValue().convertToDouble());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't use convertToDouble, keep this entirely in APFloat

%trunc = fptrunc double %0 to float
%result = fcmp oge float %trunc, 1.000000e+02
ret i1 %result
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test more combinations of types, including vectors. Also test flag preservation behavior

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A test of fp128, x86_fp80, or ppc_fp128 in particular would be helpful.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test more combinations of types, including vectors. Also test flag preservation behavior
@arsenm Sorry, could you give me a hint what 'flag preservation behavior' means in IR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean the flags on the compare should be preserved, and you don't have any tests using fast math flags. e.g. https://alive2.llvm.org/ce/z/uQr4-J

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A test of fp128, x86_fp80, or ppc_fp128 in particular would be helpful.

Sorry for late.

I tried to test these types, but an error is generated whose message is "floating point constant does not have type 'fp128'".

The error is emitted in the parse stage, and more modifications might be required to be conducted.

Should i ignore this problem, or if there are better solutions?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The IR syntax for floating point constants is unnecessarily painful. You need to use a different prefix and have the correct hex length depending on the format. For this I usually just write fpext from a reasonable FP type to the long one, and run it through instsimplify to see how it should be printed

@aengelke
Copy link
Contributor

This appears to be incorrect w.r.t. round-to-nearest rounding of fptrunc. alive2 The constant needs adjustment.

@kissholic
Copy link
Author

This appears to be incorrect w.r.t. round-to-nearest rounding of fptrunc. alive2 The constant needs adjustment.

It seems that the double type the fp constant converted to can express the same value with float type without lossing accuracy. The rmNearestTiesToEven has already been applied, and no difference appeared.🫣


if (RHSC->getType()->isVectorTy()) {
Type *LVecType = LHSI->getOperand(0)->getType();
Type *LEleType = dyn_cast<VectorType>(LVecType)->getElementType();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unchecked dyn_cast. Use the dyn_cast in the if expression instead of using isVectorTy

auto *ROperand = llvm::ConstantFP::get(
LType, dyn_cast<ConstantFP>(RHSC)->getValue().convertToDouble());
std::vector<Constant *> EleVec(EleNum);
for (uint64_t Idx = 0; Idx < EleNum; ++Idx) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is trying too hard to handle the non-splat case. Just use m_APFloat which handles splat vectors and scalars at the same time

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So i should split splat vector case from the normal cases, and combine it with scalars, if i understand correctly?😖

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You shouldn't have to really worry about the vector case. If you use m_APFloat, it should just work. It will handle scalars and splat vectors

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You shouldn't have to really worry about the vector case. If you use m_APFloat, it should just work. It will handle scalars and splat vectors

Sorry, i still didn't get the point... It seems that m_APFloat can't cover the non-splat vector cases.

I also read the icmp-trunc optimization code and ran an int vector test case, but it seems not work in the int vector case.

I came up with an idea that FPExtInst may be applied to the constant, and left the optimization to constant extension part 😋 (joking).

Could you give some more hints? Thank you <3

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that m_APFloat can't cover the non-splat vector cases.

Correct. I'm saying it's a waste of time, and will multiply the patch size, to handle non-splat cases. If you really want to handle non-splat cases, it should be a follow up after the simple patch

ret <4 x i1> %cmp
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a scalable vector test

@aengelke
Copy link
Contributor

It seems that the double type the fp constant converted to can express the same value with float type without lossing accuracy. The rmNearestTiesToEven has already been applied, and no difference appeared.🫣

It's the opposite direction that is problematic. Consider input 99.99999999. After the fptrunc, it will be value 100.0f, which is >= 100.0f. But 99.99999999 is not >= 100.0. You need to find the smallest(/largest) value of the larger floating-point type which, after truncation, is satisfies the condition.

The contributor guide also says that you should provide alive2 proofs that your transformation is correct. I provided a proof above that the transformation as implemented/tested now is not correct.

Copy link
Contributor

@arsenm arsenm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add the alive2 link proof to the description

@@ -7882,6 +7892,79 @@ static Instruction *foldFCmpReciprocalAndZero(FCmpInst &I, Instruction *LHSI,
return new FCmpInst(Pred, LHSI->getOperand(1), RHSC, "", &I);
}

// Fold trunc(x) < constant --> x < constant if possible.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fptrunc, not trunc

Comment on lines 7901 to 7905
if ((Pred == FCmpInst::FCMP_OGE) || (Pred == FCmpInst::FCMP_UGE) ||
(Pred == FCmpInst::FCMP_OLT) || (Pred == FCmpInst::FCMP_ULT))
RoundDown = true;
else if ((Pred == FCmpInst::FCMP_OGT) || (Pred == FCmpInst::FCMP_UGT) ||
(Pred == FCmpInst::FCMP_OLE) || (Pred == FCmpInst::FCMP_ULE))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't need parentheses around all these == expressions

%result = fcmp fast oge float %trunc, 1.000000e+02
ret i1 %result
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The set of tested constants seems too simple for the complexity of the loop testing constant validity. Should have negative tests for off by one bit in each direction. Also test with the edge case constants (inf, nan) and some denormal values?

Comment on lines 7938 to 7939
APFloat DupUpBound = UpBound;
DupUpBound.next(true);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Name UpBoundNext, or similar?

APFloat LowBound = RoundDown ? ExtNextRValue : ExtRValue;
APFloat UpBound = RoundDown ? ExtRValue : ExtNextRValue;

while (true) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This loop is the hard to review part and needs some comments explaining what constants are legal

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for late, being occupied by work.

The alive proof will be posted soon.

The edge cases (inf, nan) has been folded before entering this optimization (e.g. fcmp oge x inf --> fcmp oeq x inf). Maybe filtering these cases is a good idea?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not always safe to rely on those folding before hand

kissholic added a commit to kissholic/llvm-project that referenced this pull request Oct 20, 2024
@kissholic
Copy link
Author

kissholic commented Oct 20, 2024

Exclude nan and infinity from optimization. The two 'number' requires special comparison rules, and have been optimized well by other methods, which generate bool literal directly.

Also add special treatment for the max (and the min) representable float value, due to their next value is infinity.

alive2 proof:
https://alive2.llvm.org/ce/z/rphqKP
https://alive2.llvm.org/ce/z/UhU5nG
https://alive2.llvm.org/ce/z/v3zu93
https://alive2.llvm.org/ce/z/fQxsaA
https://alive2.llvm.org/ce/z/UYnaSb
https://alive2.llvm.org/ce/z/5VRZGg

@arsenm arsenm added the floating-point Floating-point math label Oct 20, 2024
%result = fcmp olt float %trunc, -3.4028234663852885981170418348451692544e38
ret i1 %result
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test the literal nan and inf cases

Comment on lines +7901 to +8047
if (Pred == FCmpInst::FCMP_OGE || Pred == FCmpInst::FCMP_UGE ||
Pred == FCmpInst::FCMP_OLT || Pred == FCmpInst::FCMP_ULT)
RoundDown = true;
else if (Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_UGT ||
Pred == FCmpInst::FCMP_OLE || Pred == FCmpInst::FCMP_ULE)
RoundDown = false;
else
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uge, ule cases not tested. Plus negative test for others

// Set the limit of ExtNextRValue.
if (NextRValue.isInfinity()) {
ExtNextRValue = ExtRValue * Two;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No braces. Can also defer construction of the 2 constant (or avoid it by using scalbn instead)

ExtNextRValue.convert(LEleType->getFltSemantics(),
APFloat::rmNearestTiesToEven, &lossInfo);

// Binary search to find the maximal (or minimal) value after RValue promotion.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you write this with std::lower_bound/std::upper_bound?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

std::lower_bound/std::upper_bound (or similar algorithm like std::binary_search) seems only accept ForwardIt type and there is likely no suitable substitution algorithm in LLVM too, which may requires constructing a new complex wrapper struct of APFloat, such as implementing name required iteration methods, calculating the mean of two APFloat (without constructing a new APFloat divisor), defining a comparation function and so on.

Considering the internal complexity of APFloat wrapper, it is simpler to keep the original one (i think)?

Copy link

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 4a19be5d45e4b1e02c2512023151be5d56ef5744 5cc33ac5e033690481505cb722695fbf3d345478 --extensions cpp -- llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 8894a337edd..adaac13a2ad 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -7932,15 +7932,15 @@ static Instruction *foldFCmpFpTrunc(FCmpInst &I, Instruction *LHSI,
   ExtNextRValue.convert(LEleType->getFltSemantics(),
                         APFloat::rmNearestTiesToEven, &lossInfo);
 
-  // Binary search to find the maximal (or minimal) value after RValue promotion.
-  // RValue can't have special comparison rules, which means nan or inf is not
-  // allowed here.
+  // Binary search to find the maximal (or minimal) value after RValue
+  // promotion. RValue can't have special comparison rules, which means nan or
+  // inf is not allowed here.
   APFloat RoundValue{LEleType->getFltSemantics()};
   {
     APFloat Two{LEleType->getFltSemantics(), 2};
 
-    // The (negative) maximum of RValue will become infinity when rounded up (down).
-    // Set the limit of ExtNextRValue.
+    // The (negative) maximum of RValue will become infinity when rounded up
+    // (down). Set the limit of ExtNextRValue.
     if (NextRValue.isInfinity()) {
       ExtNextRValue = ExtRValue * Two;
     }

@kissholic
Copy link
Author

@arsenm
Copy link
Contributor

arsenm commented Mar 16, 2025

Why was this closed?

@kissholic
Copy link
Author

Why was this closed?

Sorry, the old commits seem to be blocked due to a merge operation. I tried to discard those commits and push it again.

@kissholic kissholic reopened this Mar 16, 2025
@kissholic
Copy link
Author

ping

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

missed optimization, fptrunc (x) >= C1 => x >= C2
6 participants