Skip to content

Add constant-folding for unary NVVM intrinsics #141233

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

LewisCrawford
Copy link
Contributor

Add support for constant-folding numerous NVVM unary arithmetic intrinsics (including f, d, and ftz_f variants):

  • nvvm.ceil.*
  • nvvm.cos.approx.*
  • nvvm.ex2.approx.*
  • nvvm.fabs.*
  • nvvm.floor.*
  • nvvm.lg2.approx.*
  • nvvm.rcp.*
  • nvvm.round.*
  • nvvm.rsqrt.approx.*
  • nvvm.saturate.*
  • nvvm.sin.approx.*
  • nvvm.sqrt.f
  • nvvm.sqrt.rn.*
  • nvvm.sqrt.approx.*

@llvmbot
Copy link
Member

llvmbot commented May 23, 2025

@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-backend-nvptx

Author: Lewis Crawford (LewisCrawford)

Changes

Add support for constant-folding numerous NVVM unary arithmetic intrinsics (including f, d, and ftz_f variants):

  • nvvm.ceil.*
  • nvvm.cos.approx.*
  • nvvm.ex2.approx.*
  • nvvm.fabs.*
  • nvvm.floor.*
  • nvvm.lg2.approx.*
  • nvvm.rcp.*
  • nvvm.round.*
  • nvvm.rsqrt.approx.*
  • nvvm.saturate.*
  • nvvm.sin.approx.*
  • nvvm.sqrt.f
  • nvvm.sqrt.rn.*
  • nvvm.sqrt.approx.*

Patch is 45.78 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141233.diff

3 Files Affected:

  • (modified) llvm/include/llvm/IR/NVVMIntrinsicUtils.h (+122)
  • (modified) llvm/lib/Analysis/ConstantFolding.cpp (+207-4)
  • (added) llvm/test/Transforms/InstSimplify/const-fold-nvvm-unary-arithmetic.ll (+1035)
diff --git a/llvm/include/llvm/IR/NVVMIntrinsicUtils.h b/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
index ce794e2573637..394430f5e5629 100644
--- a/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
+++ b/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
@@ -334,6 +334,128 @@ inline bool FMinFMaxIsXorSignAbs(Intrinsic::ID IntrinsicID) {
   return false;
 }
 
+inline bool UnaryMathIntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) {
+  case Intrinsic::nvvm_ceil_ftz_f:
+  case Intrinsic::nvvm_cos_approx_ftz_f:
+  case Intrinsic::nvvm_ex2_approx_ftz_f:
+  case Intrinsic::nvvm_fabs_ftz_f:
+  case Intrinsic::nvvm_floor_ftz_f:
+  case Intrinsic::nvvm_lg2_approx_ftz_f:
+  case Intrinsic::nvvm_round_ftz_f:
+  case Intrinsic::nvvm_rsqrt_approx_ftz_d:
+  case Intrinsic::nvvm_rsqrt_approx_ftz_f:
+  case Intrinsic::nvvm_saturate_ftz_f:
+  case Intrinsic::nvvm_sin_approx_ftz_f:
+  case Intrinsic::nvvm_sqrt_rn_ftz_f:
+  case Intrinsic::nvvm_sqrt_approx_ftz_f:
+    return true;
+  case Intrinsic::nvvm_ceil_f:
+  case Intrinsic::nvvm_ceil_d:
+  case Intrinsic::nvvm_cos_approx_f:
+  case Intrinsic::nvvm_ex2_approx_d:
+  case Intrinsic::nvvm_ex2_approx_f:
+  case Intrinsic::nvvm_fabs_d:
+  case Intrinsic::nvvm_fabs_f:
+  case Intrinsic::nvvm_floor_f:
+  case Intrinsic::nvvm_floor_d:
+  case Intrinsic::nvvm_lg2_approx_d:
+  case Intrinsic::nvvm_lg2_approx_f:
+  case Intrinsic::nvvm_round_f:
+  case Intrinsic::nvvm_round_d:
+  case Intrinsic::nvvm_rsqrt_approx_d:
+  case Intrinsic::nvvm_rsqrt_approx_f:
+  case Intrinsic::nvvm_saturate_d:
+  case Intrinsic::nvvm_saturate_f:
+  case Intrinsic::nvvm_sin_approx_f:
+  case Intrinsic::nvvm_sqrt_f:
+  case Intrinsic::nvvm_sqrt_rn_d:
+  case Intrinsic::nvvm_sqrt_rn_f:
+  case Intrinsic::nvvm_sqrt_approx_f:
+    return false;
+  }
+  llvm_unreachable("Checking FTZ flag for invalid unary intrinsic");
+  return false;
+}
+
+inline bool RCPShouldFTZ(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) {
+  case Intrinsic::nvvm_rcp_rm_ftz_f:
+  case Intrinsic::nvvm_rcp_rn_ftz_f:
+  case Intrinsic::nvvm_rcp_rp_ftz_f:
+  case Intrinsic::nvvm_rcp_rz_ftz_f:
+  case Intrinsic::nvvm_rcp_approx_ftz_f:
+  case Intrinsic::nvvm_rcp_approx_ftz_d:
+    return true;
+  case Intrinsic::nvvm_rcp_rm_d:
+  case Intrinsic::nvvm_rcp_rm_f:
+  case Intrinsic::nvvm_rcp_rn_d:
+  case Intrinsic::nvvm_rcp_rn_f:
+  case Intrinsic::nvvm_rcp_rp_d:
+  case Intrinsic::nvvm_rcp_rp_f:
+  case Intrinsic::nvvm_rcp_rz_d:
+  case Intrinsic::nvvm_rcp_rz_f:
+    return false;
+  }
+  llvm_unreachable("Checking FTZ flag for invalid rcp intrinsic");
+  return false;
+}
+
+inline APFloat::roundingMode GetRCPRoundingMode(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) {
+  case Intrinsic::nvvm_rcp_rm_f:
+  case Intrinsic::nvvm_rcp_rm_d:
+  case Intrinsic::nvvm_rcp_rm_ftz_f:
+    return APFloat::rmTowardNegative;
+
+  case Intrinsic::nvvm_rcp_approx_ftz_f:
+  case Intrinsic::nvvm_rcp_approx_ftz_d:
+  case Intrinsic::nvvm_rcp_rn_f:
+  case Intrinsic::nvvm_rcp_rn_d:
+  case Intrinsic::nvvm_rcp_rn_ftz_f:
+    return APFloat::rmNearestTiesToEven;
+
+  case Intrinsic::nvvm_rcp_rp_f:
+  case Intrinsic::nvvm_rcp_rp_d:
+  case Intrinsic::nvvm_rcp_rp_ftz_f:
+    return APFloat::rmNearestTiesToEven;
+
+  case Intrinsic::nvvm_rcp_rz_f:
+  case Intrinsic::nvvm_rcp_rz_d:
+  case Intrinsic::nvvm_rcp_rz_ftz_f:
+    return APFloat::rmTowardZero;
+  }
+  llvm_unreachable("Checking rounding mode for invalid rcp intrinsic");
+  return APFloat::roundingMode::Invalid;
+}
+
+inline bool RCPIsApprox(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) {
+  case Intrinsic::nvvm_rcp_approx_ftz_f:
+  case Intrinsic::nvvm_rcp_approx_ftz_d:
+    return true;
+
+  case Intrinsic::nvvm_rcp_rm_f:
+  case Intrinsic::nvvm_rcp_rm_d:
+  case Intrinsic::nvvm_rcp_rm_ftz_f:
+
+  case Intrinsic::nvvm_rcp_rn_f:
+  case Intrinsic::nvvm_rcp_rn_d:
+  case Intrinsic::nvvm_rcp_rn_ftz_f:
+
+  case Intrinsic::nvvm_rcp_rp_f:
+  case Intrinsic::nvvm_rcp_rp_d:
+  case Intrinsic::nvvm_rcp_rp_ftz_f:
+
+  case Intrinsic::nvvm_rcp_rz_f:
+  case Intrinsic::nvvm_rcp_rz_d:
+  case Intrinsic::nvvm_rcp_rz_ftz_f:
+    return false;
+  }
+  llvm_unreachable("Checking approx flag for invalid rcp intrinsic");
+  return false;
+}
+
 } // namespace nvvm
 } // namespace llvm
 #endif // LLVM_IR_NVVMINTRINSICUTILS_H
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 412a0e8979193..d8d3b2fad8bd1 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1791,6 +1791,69 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
   case Intrinsic::nearbyint:
   case Intrinsic::rint:
   case Intrinsic::canonicalize:
+
+  // NVVM math intrinsics:
+  case Intrinsic::nvvm_ceil_d:
+  case Intrinsic::nvvm_ceil_f:
+  case Intrinsic::nvvm_ceil_ftz_f:
+
+  case Intrinsic::nvvm_cos_approx_f:
+  case Intrinsic::nvvm_cos_approx_ftz_f:
+
+  case Intrinsic::nvvm_ex2_approx_d:
+  case Intrinsic::nvvm_ex2_approx_f:
+  case Intrinsic::nvvm_ex2_approx_ftz_f:
+
+  case Intrinsic::nvvm_fabs_d:
+  case Intrinsic::nvvm_fabs_f:
+  case Intrinsic::nvvm_fabs_ftz_f:
+
+  case Intrinsic::nvvm_floor_d:
+  case Intrinsic::nvvm_floor_f:
+  case Intrinsic::nvvm_floor_ftz_f:
+
+  case Intrinsic::nvvm_lg2_approx_d:
+  case Intrinsic::nvvm_lg2_approx_f:
+  case Intrinsic::nvvm_lg2_approx_ftz_f:
+
+  case Intrinsic::nvvm_rcp_rm_d:
+  case Intrinsic::nvvm_rcp_rm_f:
+  case Intrinsic::nvvm_rcp_rm_ftz_f:
+  case Intrinsic::nvvm_rcp_rn_d:
+  case Intrinsic::nvvm_rcp_rn_f:
+  case Intrinsic::nvvm_rcp_rn_ftz_f:
+  case Intrinsic::nvvm_rcp_rp_d:
+  case Intrinsic::nvvm_rcp_rp_f:
+  case Intrinsic::nvvm_rcp_rp_ftz_f:
+  case Intrinsic::nvvm_rcp_rz_d:
+  case Intrinsic::nvvm_rcp_rz_f:
+  case Intrinsic::nvvm_rcp_rz_ftz_f:
+  case Intrinsic::nvvm_rcp_approx_ftz_d:
+  case Intrinsic::nvvm_rcp_approx_ftz_f:
+
+  case Intrinsic::nvvm_round_d:
+  case Intrinsic::nvvm_round_f:
+  case Intrinsic::nvvm_round_ftz_f:
+
+  case Intrinsic::nvvm_rsqrt_approx_d:
+  case Intrinsic::nvvm_rsqrt_approx_f:
+  case Intrinsic::nvvm_rsqrt_approx_ftz_d:
+  case Intrinsic::nvvm_rsqrt_approx_ftz_f:
+
+  case Intrinsic::nvvm_saturate_d:
+  case Intrinsic::nvvm_saturate_f:
+  case Intrinsic::nvvm_saturate_ftz_f:
+
+  case Intrinsic::nvvm_sin_approx_f:
+  case Intrinsic::nvvm_sin_approx_ftz_f:
+
+  case Intrinsic::nvvm_sqrt_f:
+  case Intrinsic::nvvm_sqrt_rn_d:
+  case Intrinsic::nvvm_sqrt_rn_f:
+  case Intrinsic::nvvm_sqrt_rn_ftz_f:
+  case Intrinsic::nvvm_sqrt_approx_f:
+  case Intrinsic::nvvm_sqrt_approx_ftz_f:
+
   // Constrained intrinsics can be folded if FP environment is known
   // to compiler.
   case Intrinsic::experimental_constrained_fma:
@@ -1944,16 +2007,23 @@ static const APFloat FTZPreserveSign(const APFloat &V) {
   return V;
 }
 
-Constant *ConstantFoldFP(double (*NativeFP)(double), const APFloat &V,
-                         Type *Ty) {
+Constant *ConstantFoldFP(double (*NativeFP)(double), const APFloat &V, Type *Ty,
+                         bool ShouldFTZPreservingSign = false) {
   llvm_fenv_clearexcept();
-  double Result = NativeFP(V.convertToDouble());
+  auto Input = ShouldFTZPreservingSign ? FTZPreserveSign(V) : V;
+  double Result = NativeFP(Input.convertToDouble());
   if (llvm_fenv_testexcept()) {
     llvm_fenv_clearexcept();
     return nullptr;
   }
 
-  return GetConstantFoldFPValue(Result, Ty);
+  Constant *Output = GetConstantFoldFPValue(Result, Ty);
+  if (ShouldFTZPreservingSign) {
+    const auto *CFP = static_cast<ConstantFP *>(Output);
+    return ConstantFP::get(Ty->getContext(),
+                           FTZPreserveSign(CFP->getValueAPF()));
+  }
+  return Output;
 }
 
 #if defined(HAS_IEE754_FLOAT128) && defined(HAS_LOGF128)
@@ -2524,6 +2594,139 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
         return ConstantFoldFP(cosh, APF, Ty);
       case Intrinsic::sqrt:
         return ConstantFoldFP(sqrt, APF, Ty);
+
+      // NVVM Intrinsics:
+      case Intrinsic::nvvm_ceil_ftz_f:
+      case Intrinsic::nvvm_ceil_f:
+      case Intrinsic::nvvm_ceil_d:
+        return ConstantFoldFP(ceil, APF, Ty,
+                              nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
+
+      case Intrinsic::nvvm_cos_approx_ftz_f:
+      case Intrinsic::nvvm_cos_approx_f:
+        return ConstantFoldFP(cos, APF, Ty,
+                              nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
+
+      case Intrinsic::nvvm_ex2_approx_ftz_f:
+      case Intrinsic::nvvm_ex2_approx_d:
+      case Intrinsic::nvvm_ex2_approx_f:
+        return ConstantFoldFP(exp2, APF, Ty,
+                              nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
+
+      case Intrinsic::nvvm_fabs_ftz_f:
+      case Intrinsic::nvvm_fabs_d:
+      case Intrinsic::nvvm_fabs_f:
+        return ConstantFoldFP(fabs, APF, Ty,
+                              nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
+
+      case Intrinsic::nvvm_floor_ftz_f:
+      case Intrinsic::nvvm_floor_f:
+      case Intrinsic::nvvm_floor_d:
+        return ConstantFoldFP(floor, APF, Ty,
+                              nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
+
+      case Intrinsic::nvvm_lg2_approx_ftz_f:
+      case Intrinsic::nvvm_lg2_approx_d:
+      case Intrinsic::nvvm_lg2_approx_f: {
+        if (APF.isNegative() || APF.isZero())
+          return nullptr;
+        return ConstantFoldFP(log2, APF, Ty,
+                              nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
+      }
+
+      case Intrinsic::nvvm_rcp_rm_ftz_f:
+      case Intrinsic::nvvm_rcp_rn_ftz_f:
+      case Intrinsic::nvvm_rcp_rp_ftz_f:
+      case Intrinsic::nvvm_rcp_rz_ftz_f:
+      case Intrinsic::nvvm_rcp_approx_ftz_f:
+      case Intrinsic::nvvm_rcp_approx_ftz_d:
+      case Intrinsic::nvvm_rcp_rm_d:
+      case Intrinsic::nvvm_rcp_rm_f:
+      case Intrinsic::nvvm_rcp_rn_d:
+      case Intrinsic::nvvm_rcp_rn_f:
+      case Intrinsic::nvvm_rcp_rp_d:
+      case Intrinsic::nvvm_rcp_rp_f:
+      case Intrinsic::nvvm_rcp_rz_d:
+      case Intrinsic::nvvm_rcp_rz_f: {
+        APFloat::roundingMode RoundMode = nvvm::GetRCPRoundingMode(IntrinsicID);
+        bool IsApprox = nvvm::RCPIsApprox(IntrinsicID);
+        bool IsFTZ = nvvm::RCPShouldFTZ(IntrinsicID);
+
+        auto Denominator = IsFTZ ? FTZPreserveSign(APF) : APF;
+        if (IsApprox && Denominator.isZero()) {
+          // According to the PTX spec, approximate rcp should return infinity
+          // with the same sign as the denominator when dividing by 0.
+          APFloat Inf = APFloat::getInf(APF.getSemantics(), APF.isNegative());
+          return ConstantFP::get(Ty->getContext(), Inf);
+        }
+        APFloat Res = APFloat::getOne(APF.getSemantics());
+        APFloat::opStatus Status = Res.divide(Denominator, RoundMode);
+
+        if (Status == APFloat::opOK || Status == APFloat::opInexact) {
+          if (IsFTZ)
+            Res = FTZPreserveSign(Res);
+          return ConstantFP::get(Ty->getContext(), Res);
+        }
+        return nullptr;
+      }
+
+      case Intrinsic::nvvm_round_ftz_f:
+      case Intrinsic::nvvm_round_f:
+      case Intrinsic::nvvm_round_d:
+        return ConstantFoldFP(round, APF, Ty,
+                              nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
+
+      case Intrinsic::nvvm_rsqrt_approx_ftz_d:
+      case Intrinsic::nvvm_rsqrt_approx_ftz_f:
+      case Intrinsic::nvvm_rsqrt_approx_d:
+      case Intrinsic::nvvm_rsqrt_approx_f: {
+        bool IsFTZ = nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID);
+        auto V = IsFTZ ? FTZPreserveSign(APF) : APF;
+        APFloat SqrtV(sqrt(V.convertToDouble()));
+
+        bool lost;
+        SqrtV.convert(APF.getSemantics(), APFloat::rmNearestTiesToEven, &lost);
+
+        APFloat Res = APFloat::getOne(APF.getSemantics());
+        Res.divide(SqrtV, APFloat::rmNearestTiesToEven);
+
+        // We do not need to flush the output for ftz because it is impossible
+        // for 1/sqrt(x) to be a denormal value. If x is the largest fp value,
+        // sqrt(x) will be a number with the exponent approximately halved and
+        // the reciprocal of that number can't be small enough to be denormal.
+        return ConstantFP::get(Ty->getContext(), Res);
+      }
+
+      case Intrinsic::nvvm_saturate_ftz_f:
+      case Intrinsic::nvvm_saturate_d:
+      case Intrinsic::nvvm_saturate_f: {
+        bool IsFTZ = nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID);
+        auto V = IsFTZ ? FTZPreserveSign(APF) : APF;
+        if (V.isNegative() || V.isZero() || V.isNaN())
+          return ConstantFP::getZero(Ty);
+        APFloat One = APFloat::getOne(APF.getSemantics());
+        if (V > One)
+          return ConstantFP::get(Ty->getContext(), One);
+        return ConstantFP::get(Ty->getContext(), APF);
+      }
+
+      case Intrinsic::nvvm_sin_approx_ftz_f:
+      case Intrinsic::nvvm_sin_approx_f:
+        return ConstantFoldFP(sin, APF, Ty,
+                              nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
+
+      case Intrinsic::nvvm_sqrt_rn_ftz_f:
+      case Intrinsic::nvvm_sqrt_approx_ftz_f:
+      case Intrinsic::nvvm_sqrt_f:
+      case Intrinsic::nvvm_sqrt_rn_d:
+      case Intrinsic::nvvm_sqrt_rn_f:
+      case Intrinsic::nvvm_sqrt_approx_f:
+        if (APF.isNegative())
+          return nullptr;
+        return ConstantFoldFP(sqrt, APF, Ty,
+                              nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
+
+      // AMDGCN Intrinsics:
       case Intrinsic::amdgcn_cos:
       case Intrinsic::amdgcn_sin: {
         double V = getValueAsDouble(Op);
diff --git a/llvm/test/Transforms/InstSimplify/const-fold-nvvm-unary-arithmetic.ll b/llvm/test/Transforms/InstSimplify/const-fold-nvvm-unary-arithmetic.ll
new file mode 100644
index 0000000000000..51563975be233
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/const-fold-nvvm-unary-arithmetic.ll
@@ -0,0 +1,1035 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=instsimplify -march=nvptx64 -S | FileCheck %s
+
+; Test constant-folding for various NVVM unary arithmetic intrinsics.
+
+;###############################################################
+;#                          Ceil                               #
+;###############################################################
+
+define double @test_ceil_d_1_25() {
+; CHECK-LABEL: define double @test_ceil_d_1_25() {
+; CHECK-NEXT:    ret double 2.000000e+00
+;
+  %res = call double @llvm.nvvm.ceil.d(double 1.25)
+  ret double %res
+}
+
+define float @test_ceil_f_1_25() {
+; CHECK-LABEL: define float @test_ceil_f_1_25() {
+; CHECK-NEXT:    ret float 2.000000e+00
+;
+  %res = call float @llvm.nvvm.ceil.f(float 1.25)
+  ret float %res
+}
+
+define float @test_ceil_ftz_f_1_25() {
+; CHECK-LABEL: define float @test_ceil_ftz_f_1_25() {
+; CHECK-NEXT:    ret float 2.000000e+00
+;
+  %res = call float @llvm.nvvm.ceil.ftz.f(float 1.25)
+  ret float %res
+}
+
+define double @test_ceil_d_pos_subnorm() {
+; CHECK-LABEL: define double @test_ceil_d_pos_subnorm() {
+; CHECK-NEXT:    ret double 1.000000e+00
+;
+  %res = call double @llvm.nvvm.ceil.d(double 0x380FFFFFC0000000)
+  ret double %res
+}
+
+define float @test_ceil_f_pos_subnorm() {
+; CHECK-LABEL: define float @test_ceil_f_pos_subnorm() {
+; CHECK-NEXT:    ret float 1.000000e+00
+;
+  %res = call float @llvm.nvvm.ceil.f(float 0x380FFFFFC0000000)
+  ret float %res
+}
+
+define float @test_ceil_ftz_f_pos_subnorm() {
+; CHECK-LABEL: define float @test_ceil_ftz_f_pos_subnorm() {
+; CHECK-NEXT:    ret float 0.000000e+00
+;
+  %res = call float @llvm.nvvm.ceil.ftz.f(float 0x380FFFFFC0000000)
+  ret float %res
+}
+
+;###############################################################
+;#                        Cos Approx                           #
+;###############################################################
+
+define float @test_cos_approx_f_1_25() {
+; CHECK-LABEL: define float @test_cos_approx_f_1_25() {
+; CHECK-NEXT:    ret float 0x3FD42E3DE0000000
+;
+  %res = call float @llvm.nvvm.cos.approx.f(float 1.25)
+  ret float %res
+}
+
+define float @test_cos_approx_ftz_f_1_25() {
+; CHECK-LABEL: define float @test_cos_approx_ftz_f_1_25() {
+; CHECK-NEXT:    ret float 0x3FD42E3DE0000000
+;
+  %res = call float @llvm.nvvm.cos.approx.ftz.f(float 1.25)
+  ret float %res
+}
+
+define float @test_cos_approx_f_pos_subnorm() {
+; CHECK-LABEL: define float @test_cos_approx_f_pos_subnorm() {
+; CHECK-NEXT:    ret float 1.000000e+00
+;
+  %res = call float @llvm.nvvm.cos.approx.f(float 0x380FFFFFC0000000)
+  ret float %res
+}
+
+define float @test_cos_approx_ftz_f_pos_subnorm() {
+; CHECK-LABEL: define float @test_cos_approx_ftz_f_pos_subnorm() {
+; CHECK-NEXT:    ret float 1.000000e+00
+;
+  %res = call float @llvm.nvvm.cos.approx.ftz.f(float 0x380FFFFFC0000000)
+  ret float %res
+}
+
+;###############################################################
+;#                        Ex2 Approx                           #
+;###############################################################
+
+define double @test_ex2_approx_d_1_25() {
+; CHECK-LABEL: define double @test_ex2_approx_d_1_25() {
+; CHECK-NEXT:    ret double 0x400306FE0A31B715
+;
+  %res = call double @llvm.nvvm.ex2.approx.d(double 1.25)
+  ret double %res
+}
+
+define float @test_ex2_approx_f_1_25() {
+; CHECK-LABEL: define float @test_ex2_approx_f_1_25() {
+; CHECK-NEXT:    ret float 0x400306FE00000000
+;
+  %res = call float @llvm.nvvm.ex2.approx.f(float 1.25)
+  ret float %res
+}
+
+define float @test_ex2_approx_ftz_f_1_25() {
+; CHECK-LABEL: define float @test_ex2_approx_ftz_f_1_25() {
+; CHECK-NEXT:    ret float 0x400306FE00000000
+;
+  %res = call float @llvm.nvvm.ex2.approx.ftz.f(float 1.25)
+  ret float %res
+}
+
+define double @test_ex2_approx_d_pos_subnorm() {
+; CHECK-LABEL: define double @test_ex2_approx_d_pos_subnorm() {
+; CHECK-NEXT:    ret double 1.000000e+00
+;
+  %res = call double @llvm.nvvm.ex2.approx.d(double 0x380FFFFFC0000000)
+  ret double %res
+}
+
+define float @test_ex2_approx_f_pos_subnorm() {
+; CHECK-LABEL: define float @test_ex2_approx_f_pos_subnorm() {
+; CHECK-NEXT:    ret float 1.000000e+00
+;
+  %res = call float @llvm.nvvm.ex2.approx.f(float 0x380FFFFFC0000000)
+  ret float %res
+}
+
+define float @test_ex2_approx_ftz_f_pos_subnorm() {
+; CHECK-LABEL: define float @test_ex2_approx_ftz_f_pos_subnorm() {
+; CHECK-NEXT:    ret float 1.000000e+00
+;
+  %res = call float @llvm.nvvm.ex2.approx.ftz.f(float 0x380FFFFFC0000000)
+  ret float %res
+}
+
+;###############################################################
+;#                          FAbs                               #
+;###############################################################
+define double @test_fabs_d_neg_1_5() {
+; CHECK-LABEL: define double @test_fabs_d_neg_1_5() {
+; CHECK-NEXT:    ret double 1.500000e+00
+;
+  %res = call double @llvm.nvvm.fabs.d(double -1.5)
+  ret double %res
+}
+
+define float @test_fabs_f_neg_1_5() {
+; CHECK-LABEL: define float @test_fabs_f_neg_1_5() {
+; CHECK-NEXT:    ret float 1.500000e+00
+;
+  %res = call float @llvm.nvvm.fabs.f(float -1.5)
+  ret float %res
+}
+
+define float @test_fabs_ftz_f_neg_1_5() {
+; CHECK-LABEL: define float @test_fabs_ftz_f_neg_1_5() {
+; CHECK-NEXT:    ret float 1.500000e+00
+;
+  %res = call float @llvm.nvvm.fabs.ftz.f(float -1.5)
+  ret float %res
+}
+
+define double @test_fabs_d_1_25() {
+; CHECK-LABEL: define double @test_fabs_d_1_25() {
+; CHECK-NEXT:    ret double 1.250000e+00
+;
+  %res = call double @llvm.nvvm.fabs.d(double 1.25)
+  ret double %res
+}
+
+define float @test_fabs_f_1_25() {
+; CHECK-LABEL: define float @test_fabs_f_1_25() {
+; CHECK-NEXT:    ret float 1.250000e+00
+;
+  %res = call float @llvm.nvvm.fabs.f(float 1.25)
+  ret float %res
+}
+
+define float @test_fabs_ftz_f_1_25() {
+; CHECK-LABEL: define float @test_fabs_ftz_f_1_25() {
+; CHECK-NEXT:    ret float 1.250000e+00
+;
+  %res = call float @llvm.nvvm.fabs.ftz.f(float 1.25)
+  ret float %res
+}
+
+define double @test_fabs_d_neg_subnorm() {
+; CHECK-LABEL: define double @test_fabs_d_neg_subnorm() {
+; CHECK-NEXT:    ret double 0x380FFFFFC0000000
+;
+  %res = call double @llvm.nvvm.fabs.d(double 0xB80FFFFFC0000000)
+  ret double %res
+}
+
+define float @test_fabs_f_neg_subnorm() {
+; CHECK-LABEL: define float @test_fabs_f_neg_subnorm() {
+; CHECK-NEXT:    ret floa...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 23, 2025

@llvm/pr-subscribers-llvm-ir

Author: Lewis Crawford (LewisCrawford)

Changes

Add support for constant-folding numerous NVVM unary arithmetic intrinsics (including f, d, and ftz_f variants):

  • nvvm.ceil.*
  • nvvm.cos.approx.*
  • nvvm.ex2.approx.*
  • nvvm.fabs.*
  • nvvm.floor.*
  • nvvm.lg2.approx.*
  • nvvm.rcp.*
  • nvvm.round.*
  • nvvm.rsqrt.approx.*
  • nvvm.saturate.*
  • nvvm.sin.approx.*
  • nvvm.sqrt.f
  • nvvm.sqrt.rn.*
  • nvvm.sqrt.approx.*

Patch is 45.78 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141233.diff

3 Files Affected:

  • (modified) llvm/include/llvm/IR/NVVMIntrinsicUtils.h (+122)
  • (modified) llvm/lib/Analysis/ConstantFolding.cpp (+207-4)
  • (added) llvm/test/Transforms/InstSimplify/const-fold-nvvm-unary-arithmetic.ll (+1035)
diff --git a/llvm/include/llvm/IR/NVVMIntrinsicUtils.h b/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
index ce794e2573637..394430f5e5629 100644
--- a/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
+++ b/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
@@ -334,6 +334,128 @@ inline bool FMinFMaxIsXorSignAbs(Intrinsic::ID IntrinsicID) {
   return false;
 }
 
+inline bool UnaryMathIntrinsicShouldFTZ(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) {
+  case Intrinsic::nvvm_ceil_ftz_f:
+  case Intrinsic::nvvm_cos_approx_ftz_f:
+  case Intrinsic::nvvm_ex2_approx_ftz_f:
+  case Intrinsic::nvvm_fabs_ftz_f:
+  case Intrinsic::nvvm_floor_ftz_f:
+  case Intrinsic::nvvm_lg2_approx_ftz_f:
+  case Intrinsic::nvvm_round_ftz_f:
+  case Intrinsic::nvvm_rsqrt_approx_ftz_d:
+  case Intrinsic::nvvm_rsqrt_approx_ftz_f:
+  case Intrinsic::nvvm_saturate_ftz_f:
+  case Intrinsic::nvvm_sin_approx_ftz_f:
+  case Intrinsic::nvvm_sqrt_rn_ftz_f:
+  case Intrinsic::nvvm_sqrt_approx_ftz_f:
+    return true;
+  case Intrinsic::nvvm_ceil_f:
+  case Intrinsic::nvvm_ceil_d:
+  case Intrinsic::nvvm_cos_approx_f:
+  case Intrinsic::nvvm_ex2_approx_d:
+  case Intrinsic::nvvm_ex2_approx_f:
+  case Intrinsic::nvvm_fabs_d:
+  case Intrinsic::nvvm_fabs_f:
+  case Intrinsic::nvvm_floor_f:
+  case Intrinsic::nvvm_floor_d:
+  case Intrinsic::nvvm_lg2_approx_d:
+  case Intrinsic::nvvm_lg2_approx_f:
+  case Intrinsic::nvvm_round_f:
+  case Intrinsic::nvvm_round_d:
+  case Intrinsic::nvvm_rsqrt_approx_d:
+  case Intrinsic::nvvm_rsqrt_approx_f:
+  case Intrinsic::nvvm_saturate_d:
+  case Intrinsic::nvvm_saturate_f:
+  case Intrinsic::nvvm_sin_approx_f:
+  case Intrinsic::nvvm_sqrt_f:
+  case Intrinsic::nvvm_sqrt_rn_d:
+  case Intrinsic::nvvm_sqrt_rn_f:
+  case Intrinsic::nvvm_sqrt_approx_f:
+    return false;
+  }
+  llvm_unreachable("Checking FTZ flag for invalid unary intrinsic");
+  return false;
+}
+
+inline bool RCPShouldFTZ(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) {
+  case Intrinsic::nvvm_rcp_rm_ftz_f:
+  case Intrinsic::nvvm_rcp_rn_ftz_f:
+  case Intrinsic::nvvm_rcp_rp_ftz_f:
+  case Intrinsic::nvvm_rcp_rz_ftz_f:
+  case Intrinsic::nvvm_rcp_approx_ftz_f:
+  case Intrinsic::nvvm_rcp_approx_ftz_d:
+    return true;
+  case Intrinsic::nvvm_rcp_rm_d:
+  case Intrinsic::nvvm_rcp_rm_f:
+  case Intrinsic::nvvm_rcp_rn_d:
+  case Intrinsic::nvvm_rcp_rn_f:
+  case Intrinsic::nvvm_rcp_rp_d:
+  case Intrinsic::nvvm_rcp_rp_f:
+  case Intrinsic::nvvm_rcp_rz_d:
+  case Intrinsic::nvvm_rcp_rz_f:
+    return false;
+  }
+  llvm_unreachable("Checking FTZ flag for invalid rcp intrinsic");
+  return false;
+}
+
+inline APFloat::roundingMode GetRCPRoundingMode(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) {
+  case Intrinsic::nvvm_rcp_rm_f:
+  case Intrinsic::nvvm_rcp_rm_d:
+  case Intrinsic::nvvm_rcp_rm_ftz_f:
+    return APFloat::rmTowardNegative;
+
+  case Intrinsic::nvvm_rcp_approx_ftz_f:
+  case Intrinsic::nvvm_rcp_approx_ftz_d:
+  case Intrinsic::nvvm_rcp_rn_f:
+  case Intrinsic::nvvm_rcp_rn_d:
+  case Intrinsic::nvvm_rcp_rn_ftz_f:
+    return APFloat::rmNearestTiesToEven;
+
+  case Intrinsic::nvvm_rcp_rp_f:
+  case Intrinsic::nvvm_rcp_rp_d:
+  case Intrinsic::nvvm_rcp_rp_ftz_f:
+    return APFloat::rmNearestTiesToEven;
+
+  case Intrinsic::nvvm_rcp_rz_f:
+  case Intrinsic::nvvm_rcp_rz_d:
+  case Intrinsic::nvvm_rcp_rz_ftz_f:
+    return APFloat::rmTowardZero;
+  }
+  llvm_unreachable("Checking rounding mode for invalid rcp intrinsic");
+  return APFloat::roundingMode::Invalid;
+}
+
+inline bool RCPIsApprox(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) {
+  case Intrinsic::nvvm_rcp_approx_ftz_f:
+  case Intrinsic::nvvm_rcp_approx_ftz_d:
+    return true;
+
+  case Intrinsic::nvvm_rcp_rm_f:
+  case Intrinsic::nvvm_rcp_rm_d:
+  case Intrinsic::nvvm_rcp_rm_ftz_f:
+
+  case Intrinsic::nvvm_rcp_rn_f:
+  case Intrinsic::nvvm_rcp_rn_d:
+  case Intrinsic::nvvm_rcp_rn_ftz_f:
+
+  case Intrinsic::nvvm_rcp_rp_f:
+  case Intrinsic::nvvm_rcp_rp_d:
+  case Intrinsic::nvvm_rcp_rp_ftz_f:
+
+  case Intrinsic::nvvm_rcp_rz_f:
+  case Intrinsic::nvvm_rcp_rz_d:
+  case Intrinsic::nvvm_rcp_rz_ftz_f:
+    return false;
+  }
+  llvm_unreachable("Checking approx flag for invalid rcp intrinsic");
+  return false;
+}
+
 } // namespace nvvm
 } // namespace llvm
 #endif // LLVM_IR_NVVMINTRINSICUTILS_H
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 412a0e8979193..d8d3b2fad8bd1 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1791,6 +1791,69 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
   case Intrinsic::nearbyint:
   case Intrinsic::rint:
   case Intrinsic::canonicalize:
+
+  // NVVM math intrinsics:
+  case Intrinsic::nvvm_ceil_d:
+  case Intrinsic::nvvm_ceil_f:
+  case Intrinsic::nvvm_ceil_ftz_f:
+
+  case Intrinsic::nvvm_cos_approx_f:
+  case Intrinsic::nvvm_cos_approx_ftz_f:
+
+  case Intrinsic::nvvm_ex2_approx_d:
+  case Intrinsic::nvvm_ex2_approx_f:
+  case Intrinsic::nvvm_ex2_approx_ftz_f:
+
+  case Intrinsic::nvvm_fabs_d:
+  case Intrinsic::nvvm_fabs_f:
+  case Intrinsic::nvvm_fabs_ftz_f:
+
+  case Intrinsic::nvvm_floor_d:
+  case Intrinsic::nvvm_floor_f:
+  case Intrinsic::nvvm_floor_ftz_f:
+
+  case Intrinsic::nvvm_lg2_approx_d:
+  case Intrinsic::nvvm_lg2_approx_f:
+  case Intrinsic::nvvm_lg2_approx_ftz_f:
+
+  case Intrinsic::nvvm_rcp_rm_d:
+  case Intrinsic::nvvm_rcp_rm_f:
+  case Intrinsic::nvvm_rcp_rm_ftz_f:
+  case Intrinsic::nvvm_rcp_rn_d:
+  case Intrinsic::nvvm_rcp_rn_f:
+  case Intrinsic::nvvm_rcp_rn_ftz_f:
+  case Intrinsic::nvvm_rcp_rp_d:
+  case Intrinsic::nvvm_rcp_rp_f:
+  case Intrinsic::nvvm_rcp_rp_ftz_f:
+  case Intrinsic::nvvm_rcp_rz_d:
+  case Intrinsic::nvvm_rcp_rz_f:
+  case Intrinsic::nvvm_rcp_rz_ftz_f:
+  case Intrinsic::nvvm_rcp_approx_ftz_d:
+  case Intrinsic::nvvm_rcp_approx_ftz_f:
+
+  case Intrinsic::nvvm_round_d:
+  case Intrinsic::nvvm_round_f:
+  case Intrinsic::nvvm_round_ftz_f:
+
+  case Intrinsic::nvvm_rsqrt_approx_d:
+  case Intrinsic::nvvm_rsqrt_approx_f:
+  case Intrinsic::nvvm_rsqrt_approx_ftz_d:
+  case Intrinsic::nvvm_rsqrt_approx_ftz_f:
+
+  case Intrinsic::nvvm_saturate_d:
+  case Intrinsic::nvvm_saturate_f:
+  case Intrinsic::nvvm_saturate_ftz_f:
+
+  case Intrinsic::nvvm_sin_approx_f:
+  case Intrinsic::nvvm_sin_approx_ftz_f:
+
+  case Intrinsic::nvvm_sqrt_f:
+  case Intrinsic::nvvm_sqrt_rn_d:
+  case Intrinsic::nvvm_sqrt_rn_f:
+  case Intrinsic::nvvm_sqrt_rn_ftz_f:
+  case Intrinsic::nvvm_sqrt_approx_f:
+  case Intrinsic::nvvm_sqrt_approx_ftz_f:
+
   // Constrained intrinsics can be folded if FP environment is known
   // to compiler.
   case Intrinsic::experimental_constrained_fma:
@@ -1944,16 +2007,23 @@ static const APFloat FTZPreserveSign(const APFloat &V) {
   return V;
 }
 
-Constant *ConstantFoldFP(double (*NativeFP)(double), const APFloat &V,
-                         Type *Ty) {
+Constant *ConstantFoldFP(double (*NativeFP)(double), const APFloat &V, Type *Ty,
+                         bool ShouldFTZPreservingSign = false) {
   llvm_fenv_clearexcept();
-  double Result = NativeFP(V.convertToDouble());
+  auto Input = ShouldFTZPreservingSign ? FTZPreserveSign(V) : V;
+  double Result = NativeFP(Input.convertToDouble());
   if (llvm_fenv_testexcept()) {
     llvm_fenv_clearexcept();
     return nullptr;
   }
 
-  return GetConstantFoldFPValue(Result, Ty);
+  Constant *Output = GetConstantFoldFPValue(Result, Ty);
+  if (ShouldFTZPreservingSign) {
+    const auto *CFP = static_cast<ConstantFP *>(Output);
+    return ConstantFP::get(Ty->getContext(),
+                           FTZPreserveSign(CFP->getValueAPF()));
+  }
+  return Output;
 }
 
 #if defined(HAS_IEE754_FLOAT128) && defined(HAS_LOGF128)
@@ -2524,6 +2594,139 @@ static Constant *ConstantFoldScalarCall1(StringRef Name,
         return ConstantFoldFP(cosh, APF, Ty);
       case Intrinsic::sqrt:
         return ConstantFoldFP(sqrt, APF, Ty);
+
+      // NVVM Intrinsics:
+      case Intrinsic::nvvm_ceil_ftz_f:
+      case Intrinsic::nvvm_ceil_f:
+      case Intrinsic::nvvm_ceil_d:
+        return ConstantFoldFP(ceil, APF, Ty,
+                              nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
+
+      case Intrinsic::nvvm_cos_approx_ftz_f:
+      case Intrinsic::nvvm_cos_approx_f:
+        return ConstantFoldFP(cos, APF, Ty,
+                              nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
+
+      case Intrinsic::nvvm_ex2_approx_ftz_f:
+      case Intrinsic::nvvm_ex2_approx_d:
+      case Intrinsic::nvvm_ex2_approx_f:
+        return ConstantFoldFP(exp2, APF, Ty,
+                              nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
+
+      case Intrinsic::nvvm_fabs_ftz_f:
+      case Intrinsic::nvvm_fabs_d:
+      case Intrinsic::nvvm_fabs_f:
+        return ConstantFoldFP(fabs, APF, Ty,
+                              nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
+
+      case Intrinsic::nvvm_floor_ftz_f:
+      case Intrinsic::nvvm_floor_f:
+      case Intrinsic::nvvm_floor_d:
+        return ConstantFoldFP(floor, APF, Ty,
+                              nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
+
+      case Intrinsic::nvvm_lg2_approx_ftz_f:
+      case Intrinsic::nvvm_lg2_approx_d:
+      case Intrinsic::nvvm_lg2_approx_f: {
+        if (APF.isNegative() || APF.isZero())
+          return nullptr;
+        return ConstantFoldFP(log2, APF, Ty,
+                              nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
+      }
+
+      case Intrinsic::nvvm_rcp_rm_ftz_f:
+      case Intrinsic::nvvm_rcp_rn_ftz_f:
+      case Intrinsic::nvvm_rcp_rp_ftz_f:
+      case Intrinsic::nvvm_rcp_rz_ftz_f:
+      case Intrinsic::nvvm_rcp_approx_ftz_f:
+      case Intrinsic::nvvm_rcp_approx_ftz_d:
+      case Intrinsic::nvvm_rcp_rm_d:
+      case Intrinsic::nvvm_rcp_rm_f:
+      case Intrinsic::nvvm_rcp_rn_d:
+      case Intrinsic::nvvm_rcp_rn_f:
+      case Intrinsic::nvvm_rcp_rp_d:
+      case Intrinsic::nvvm_rcp_rp_f:
+      case Intrinsic::nvvm_rcp_rz_d:
+      case Intrinsic::nvvm_rcp_rz_f: {
+        APFloat::roundingMode RoundMode = nvvm::GetRCPRoundingMode(IntrinsicID);
+        bool IsApprox = nvvm::RCPIsApprox(IntrinsicID);
+        bool IsFTZ = nvvm::RCPShouldFTZ(IntrinsicID);
+
+        auto Denominator = IsFTZ ? FTZPreserveSign(APF) : APF;
+        if (IsApprox && Denominator.isZero()) {
+          // According to the PTX spec, approximate rcp should return infinity
+          // with the same sign as the denominator when dividing by 0.
+          APFloat Inf = APFloat::getInf(APF.getSemantics(), APF.isNegative());
+          return ConstantFP::get(Ty->getContext(), Inf);
+        }
+        APFloat Res = APFloat::getOne(APF.getSemantics());
+        APFloat::opStatus Status = Res.divide(Denominator, RoundMode);
+
+        if (Status == APFloat::opOK || Status == APFloat::opInexact) {
+          if (IsFTZ)
+            Res = FTZPreserveSign(Res);
+          return ConstantFP::get(Ty->getContext(), Res);
+        }
+        return nullptr;
+      }
+
+      case Intrinsic::nvvm_round_ftz_f:
+      case Intrinsic::nvvm_round_f:
+      case Intrinsic::nvvm_round_d:
+        return ConstantFoldFP(round, APF, Ty,
+                              nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
+
+      case Intrinsic::nvvm_rsqrt_approx_ftz_d:
+      case Intrinsic::nvvm_rsqrt_approx_ftz_f:
+      case Intrinsic::nvvm_rsqrt_approx_d:
+      case Intrinsic::nvvm_rsqrt_approx_f: {
+        bool IsFTZ = nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID);
+        auto V = IsFTZ ? FTZPreserveSign(APF) : APF;
+        APFloat SqrtV(sqrt(V.convertToDouble()));
+
+        bool lost;
+        SqrtV.convert(APF.getSemantics(), APFloat::rmNearestTiesToEven, &lost);
+
+        APFloat Res = APFloat::getOne(APF.getSemantics());
+        Res.divide(SqrtV, APFloat::rmNearestTiesToEven);
+
+        // We do not need to flush the output for ftz because it is impossible
+        // for 1/sqrt(x) to be a denormal value. If x is the largest fp value,
+        // sqrt(x) will be a number with the exponent approximately halved and
+        // the reciprocal of that number can't be small enough to be denormal.
+        return ConstantFP::get(Ty->getContext(), Res);
+      }
+
+      case Intrinsic::nvvm_saturate_ftz_f:
+      case Intrinsic::nvvm_saturate_d:
+      case Intrinsic::nvvm_saturate_f: {
+        bool IsFTZ = nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID);
+        auto V = IsFTZ ? FTZPreserveSign(APF) : APF;
+        if (V.isNegative() || V.isZero() || V.isNaN())
+          return ConstantFP::getZero(Ty);
+        APFloat One = APFloat::getOne(APF.getSemantics());
+        if (V > One)
+          return ConstantFP::get(Ty->getContext(), One);
+        return ConstantFP::get(Ty->getContext(), APF);
+      }
+
+      case Intrinsic::nvvm_sin_approx_ftz_f:
+      case Intrinsic::nvvm_sin_approx_f:
+        return ConstantFoldFP(sin, APF, Ty,
+                              nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
+
+      case Intrinsic::nvvm_sqrt_rn_ftz_f:
+      case Intrinsic::nvvm_sqrt_approx_ftz_f:
+      case Intrinsic::nvvm_sqrt_f:
+      case Intrinsic::nvvm_sqrt_rn_d:
+      case Intrinsic::nvvm_sqrt_rn_f:
+      case Intrinsic::nvvm_sqrt_approx_f:
+        if (APF.isNegative())
+          return nullptr;
+        return ConstantFoldFP(sqrt, APF, Ty,
+                              nvvm::UnaryMathIntrinsicShouldFTZ(IntrinsicID));
+
+      // AMDGCN Intrinsics:
       case Intrinsic::amdgcn_cos:
       case Intrinsic::amdgcn_sin: {
         double V = getValueAsDouble(Op);
diff --git a/llvm/test/Transforms/InstSimplify/const-fold-nvvm-unary-arithmetic.ll b/llvm/test/Transforms/InstSimplify/const-fold-nvvm-unary-arithmetic.ll
new file mode 100644
index 0000000000000..51563975be233
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/const-fold-nvvm-unary-arithmetic.ll
@@ -0,0 +1,1035 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt < %s -passes=instsimplify -march=nvptx64 -S | FileCheck %s
+
+; Test constant-folding for various NVVM unary arithmetic intrinsics.
+
+;###############################################################
+;#                          Ceil                               #
+;###############################################################
+
+define double @test_ceil_d_1_25() {
+; CHECK-LABEL: define double @test_ceil_d_1_25() {
+; CHECK-NEXT:    ret double 2.000000e+00
+;
+  %res = call double @llvm.nvvm.ceil.d(double 1.25)
+  ret double %res
+}
+
+define float @test_ceil_f_1_25() {
+; CHECK-LABEL: define float @test_ceil_f_1_25() {
+; CHECK-NEXT:    ret float 2.000000e+00
+;
+  %res = call float @llvm.nvvm.ceil.f(float 1.25)
+  ret float %res
+}
+
+define float @test_ceil_ftz_f_1_25() {
+; CHECK-LABEL: define float @test_ceil_ftz_f_1_25() {
+; CHECK-NEXT:    ret float 2.000000e+00
+;
+  %res = call float @llvm.nvvm.ceil.ftz.f(float 1.25)
+  ret float %res
+}
+
+define double @test_ceil_d_pos_subnorm() {
+; CHECK-LABEL: define double @test_ceil_d_pos_subnorm() {
+; CHECK-NEXT:    ret double 1.000000e+00
+;
+  %res = call double @llvm.nvvm.ceil.d(double 0x380FFFFFC0000000)
+  ret double %res
+}
+
+define float @test_ceil_f_pos_subnorm() {
+; CHECK-LABEL: define float @test_ceil_f_pos_subnorm() {
+; CHECK-NEXT:    ret float 1.000000e+00
+;
+  %res = call float @llvm.nvvm.ceil.f(float 0x380FFFFFC0000000)
+  ret float %res
+}
+
+define float @test_ceil_ftz_f_pos_subnorm() {
+; CHECK-LABEL: define float @test_ceil_ftz_f_pos_subnorm() {
+; CHECK-NEXT:    ret float 0.000000e+00
+;
+  %res = call float @llvm.nvvm.ceil.ftz.f(float 0x380FFFFFC0000000)
+  ret float %res
+}
+
+;###############################################################
+;#                        Cos Approx                           #
+;###############################################################
+
+define float @test_cos_approx_f_1_25() {
+; CHECK-LABEL: define float @test_cos_approx_f_1_25() {
+; CHECK-NEXT:    ret float 0x3FD42E3DE0000000
+;
+  %res = call float @llvm.nvvm.cos.approx.f(float 1.25)
+  ret float %res
+}
+
+define float @test_cos_approx_ftz_f_1_25() {
+; CHECK-LABEL: define float @test_cos_approx_ftz_f_1_25() {
+; CHECK-NEXT:    ret float 0x3FD42E3DE0000000
+;
+  %res = call float @llvm.nvvm.cos.approx.ftz.f(float 1.25)
+  ret float %res
+}
+
+define float @test_cos_approx_f_pos_subnorm() {
+; CHECK-LABEL: define float @test_cos_approx_f_pos_subnorm() {
+; CHECK-NEXT:    ret float 1.000000e+00
+;
+  %res = call float @llvm.nvvm.cos.approx.f(float 0x380FFFFFC0000000)
+  ret float %res
+}
+
+define float @test_cos_approx_ftz_f_pos_subnorm() {
+; CHECK-LABEL: define float @test_cos_approx_ftz_f_pos_subnorm() {
+; CHECK-NEXT:    ret float 1.000000e+00
+;
+  %res = call float @llvm.nvvm.cos.approx.ftz.f(float 0x380FFFFFC0000000)
+  ret float %res
+}
+
+;###############################################################
+;#                        Ex2 Approx                           #
+;###############################################################
+
+define double @test_ex2_approx_d_1_25() {
+; CHECK-LABEL: define double @test_ex2_approx_d_1_25() {
+; CHECK-NEXT:    ret double 0x400306FE0A31B715
+;
+  %res = call double @llvm.nvvm.ex2.approx.d(double 1.25)
+  ret double %res
+}
+
+define float @test_ex2_approx_f_1_25() {
+; CHECK-LABEL: define float @test_ex2_approx_f_1_25() {
+; CHECK-NEXT:    ret float 0x400306FE00000000
+;
+  %res = call float @llvm.nvvm.ex2.approx.f(float 1.25)
+  ret float %res
+}
+
+define float @test_ex2_approx_ftz_f_1_25() {
+; CHECK-LABEL: define float @test_ex2_approx_ftz_f_1_25() {
+; CHECK-NEXT:    ret float 0x400306FE00000000
+;
+  %res = call float @llvm.nvvm.ex2.approx.ftz.f(float 1.25)
+  ret float %res
+}
+
+define double @test_ex2_approx_d_pos_subnorm() {
+; CHECK-LABEL: define double @test_ex2_approx_d_pos_subnorm() {
+; CHECK-NEXT:    ret double 1.000000e+00
+;
+  %res = call double @llvm.nvvm.ex2.approx.d(double 0x380FFFFFC0000000)
+  ret double %res
+}
+
+define float @test_ex2_approx_f_pos_subnorm() {
+; CHECK-LABEL: define float @test_ex2_approx_f_pos_subnorm() {
+; CHECK-NEXT:    ret float 1.000000e+00
+;
+  %res = call float @llvm.nvvm.ex2.approx.f(float 0x380FFFFFC0000000)
+  ret float %res
+}
+
+define float @test_ex2_approx_ftz_f_pos_subnorm() {
+; CHECK-LABEL: define float @test_ex2_approx_ftz_f_pos_subnorm() {
+; CHECK-NEXT:    ret float 1.000000e+00
+;
+  %res = call float @llvm.nvvm.ex2.approx.ftz.f(float 0x380FFFFFC0000000)
+  ret float %res
+}
+
+;###############################################################
+;#                          FAbs                               #
+;###############################################################
+define double @test_fabs_d_neg_1_5() {
+; CHECK-LABEL: define double @test_fabs_d_neg_1_5() {
+; CHECK-NEXT:    ret double 1.500000e+00
+;
+  %res = call double @llvm.nvvm.fabs.d(double -1.5)
+  ret double %res
+}
+
+define float @test_fabs_f_neg_1_5() {
+; CHECK-LABEL: define float @test_fabs_f_neg_1_5() {
+; CHECK-NEXT:    ret float 1.500000e+00
+;
+  %res = call float @llvm.nvvm.fabs.f(float -1.5)
+  ret float %res
+}
+
+define float @test_fabs_ftz_f_neg_1_5() {
+; CHECK-LABEL: define float @test_fabs_ftz_f_neg_1_5() {
+; CHECK-NEXT:    ret float 1.500000e+00
+;
+  %res = call float @llvm.nvvm.fabs.ftz.f(float -1.5)
+  ret float %res
+}
+
+define double @test_fabs_d_1_25() {
+; CHECK-LABEL: define double @test_fabs_d_1_25() {
+; CHECK-NEXT:    ret double 1.250000e+00
+;
+  %res = call double @llvm.nvvm.fabs.d(double 1.25)
+  ret double %res
+}
+
+define float @test_fabs_f_1_25() {
+; CHECK-LABEL: define float @test_fabs_f_1_25() {
+; CHECK-NEXT:    ret float 1.250000e+00
+;
+  %res = call float @llvm.nvvm.fabs.f(float 1.25)
+  ret float %res
+}
+
+define float @test_fabs_ftz_f_1_25() {
+; CHECK-LABEL: define float @test_fabs_ftz_f_1_25() {
+; CHECK-NEXT:    ret float 1.250000e+00
+;
+  %res = call float @llvm.nvvm.fabs.ftz.f(float 1.25)
+  ret float %res
+}
+
+define double @test_fabs_d_neg_subnorm() {
+; CHECK-LABEL: define double @test_fabs_d_neg_subnorm() {
+; CHECK-NEXT:    ret double 0x380FFFFFC0000000
+;
+  %res = call double @llvm.nvvm.fabs.d(double 0xB80FFFFFC0000000)
+  ret double %res
+}
+
+define float @test_fabs_f_neg_subnorm() {
+; CHECK-LABEL: define float @test_fabs_f_neg_subnorm() {
+; CHECK-NEXT:    ret floa...
[truncated]

@LewisCrawford LewisCrawford force-pushed the fold_nvvm_math_intrinsics branch from a952b64 to 2338e8c Compare May 23, 2025 13:25
@LewisCrawford
Copy link
Contributor Author

I would like to land this patch: #140270 before merging this. That patch adds an option to disable constant-folding for FP function calls, and may be useful in the few cases where the device-side execution of these intrinsics must be a bitwise-exact match to the host-side constant-folded version of the results (or when debugging).

Add support for constant-folding numerous NVVM unary arithmetic
intrinsics (including f, d, and ftz_f variants):
  - nvvm.ceil.*
  - nvvm.cos.approx.*
  - nvvm.ex2.approx.*
  - nvvm.fabs.*
  - nvvm.floor.*
  - nvvm.lg2.approx.*
  - nvvm.rcp.*
  - nvvm.round.*
  - nvvm.rsqrt.approx.*
  - nvvm.saturate.*
  - nvvm.sin.approx.*
  - nvvm.sqrt.f
  - nvvm.sqrt.rn.*
  - nvvm.sqrt.approx.*
@LewisCrawford LewisCrawford force-pushed the fold_nvvm_math_intrinsics branch from 2338e8c to 9faec17 Compare May 23, 2025 16:20
@LewisCrawford LewisCrawford requested review from Artem-B and AlexMaclean and removed request for Artem-B May 23, 2025 16:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants