From 35553d452b32e9356352df8536fa0485207a9274 Mon Sep 17 00:00:00 2001 From: Emilio Cota Date: Sat, 23 Oct 2021 04:48:24 -0700 Subject: [PATCH] [mlir] Add polynomial approximation for vectorized math::Rsqrt This patch adds a polynomial approximation that matches the approximation in Eigen. Note that the approximation only applies to vectorized inputs; the scalar rsqrt is left unmodified. The approximation is protected with a flag since it emits an AVX2 intrinsic (generated via the X86Vector). This is the only reasonably clean way that I could find to generate the exact approximation that I wanted (i.e. an identical one to Eigen's). I considered two alternatives: 1. Introduce a Rsqrt intrinsic in LLVM, which doesn't exist yet. I believe this is because there is no definition of Rsqrt that all backends could agree on, since hardware instructions that implement it have widely varying degrees of precision. This is something that the standard could mandate, but Rsqrt is not part of IEEE754, so I don't think this option is feasible. 2. Emit fdiv(1.0, sqrt) with fast math flags to allow reciprocal transformations. Although portable, this doesn't allow us to generate exactly the code we want; it is the LLVM backend, and not MLIR, who controls what code is generated based on the target CPU. Reviewed By: ezhulenev Differential Revision: https://reviews.llvm.org/D112192 --- .../mlir/Dialect/Math/Transforms/Passes.h | 9 ++- .../Dialect/Math/Transforms/CMakeLists.txt | 1 + .../Transforms/PolynomialApproximation.cpp | 69 ++++++++++++++++++- .../Math/polynomial-approximation.mlir | 36 ++++++++++ mlir/test/lib/Dialect/Math/CMakeLists.txt | 1 + .../Math/TestPolynomialApproximation.cpp | 17 ++++- .../mlir-cpu-runner/X86Vector/lit.local.cfg | 5 ++ .../math_polynomial_approx_avx2.mlir | 40 +++++++++++ .../llvm-project-overlay/mlir/BUILD.bazel | 1 + .../mlir/test/BUILD.bazel | 1 + 10 files changed, 177 insertions(+), 3 deletions(-) create mode 100644 mlir/test/mlir-cpu-runner/X86Vector/lit.local.cfg create mode 100644 mlir/test/mlir-cpu-runner/X86Vector/math_polynomial_approx_avx2.mlir diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h index 4378f177fa0a1..8de5782fe9c9f 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h @@ -17,7 +17,14 @@ void populateExpandTanhPattern(RewritePatternSet &patterns); void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns); -void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns); +struct MathPolynomialApproximationOptions { + // Enables the use of AVX2 intrinsics in some of the approximations. + bool enableAvx2 = false; +}; + +void populateMathPolynomialApproximationPatterns( + RewritePatternSet &patterns, + const MathPolynomialApproximationOptions &options = {}); } // namespace mlir diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt index cd17dc0fa27f6..c2182562fc244 100644 --- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt @@ -13,4 +13,5 @@ add_mlir_dialect_library(MLIRMathTransforms MLIRPass MLIRStandard MLIRTransforms + MLIRX86Vector ) diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp index 7dddcfa45be79..3761b48c569e9 100644 --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/IR/Builders.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Transforms/Bufferize.h" @@ -778,13 +779,79 @@ LogicalResult SinAndCosApproximation::matchAndRewrite( return success(); } +//----------------------------------------------------------------------------// +// Rsqrt approximation. +//----------------------------------------------------------------------------// + +namespace { +struct RsqrtApproximation : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(math::RsqrtOp op, + PatternRewriter &rewriter) const final; +}; +} // namespace + +LogicalResult +RsqrtApproximation::matchAndRewrite(math::RsqrtOp op, + PatternRewriter &rewriter) const { + auto width = vectorWidth(op.operand().getType(), isF32); + // Only support already-vectorized rsqrt's. + if (!width.hasValue() || *width != 8) + return rewriter.notifyMatchFailure(op, "unsupported operand type"); + + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); + auto bcast = [&](Value value) -> Value { + return broadcast(builder, value, *width); + }; + + Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u)); + Value cstOnePointFive = bcast(f32Cst(builder, 1.5f)); + Value cstNegHalf = bcast(f32Cst(builder, -0.5f)); + Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u)); + + Value negHalf = builder.create(op.operand(), cstNegHalf); + + // Select only the inverse sqrt of positive normals (denormals are + // flushed to zero). + Value ltMinMask = builder.create(arith::CmpFPredicate::OLT, + op.operand(), cstMinNormPos); + Value infMask = builder.create(arith::CmpFPredicate::OEQ, + op.operand(), cstPosInf); + Value notNormalFiniteMask = builder.create(ltMinMask, infMask); + + // Compute an approximate result. + Value yApprox = builder.create(op.operand()); + + // Do a single step of Newton-Raphson iteration to improve the approximation. + // This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n). + // It is essential to evaluate the inner term like this because forming + // y_n^2 may over- or underflow. + Value inner = builder.create(negHalf, yApprox); + Value fma = builder.create(yApprox, inner, cstOnePointFive); + Value yNewton = builder.create(yApprox, fma); + + // Select the result of the Newton-Raphson step for positive normal arguments. + // For other arguments, choose the output of the intrinsic. This will + // return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if + // x is zero or a positive denormalized float (equivalent to flushing positive + // denormalized inputs to zero). + Value res = builder.create(notNormalFiniteMask, yApprox, yNewton); + rewriter.replaceOp(op, res); + + return success(); +} + //----------------------------------------------------------------------------// void mlir::populateMathPolynomialApproximationPatterns( - RewritePatternSet &patterns) { + RewritePatternSet &patterns, + const MathPolynomialApproximationOptions &options) { patterns.add, SinAndCosApproximation>( patterns.getContext()); + if (options.enableAvx2) + patterns.add(patterns.getContext()); } diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir index bc6f39bd57034..9ba7c47bc5d47 100644 --- a/mlir/test/Dialect/Math/polynomial-approximation.mlir +++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir @@ -1,4 +1,6 @@ // RUN: mlir-opt %s -test-math-polynomial-approximation | FileCheck %s +// RUN: mlir-opt %s -test-math-polynomial-approximation=enable-avx2 \ +// RUN: | FileCheck --check-prefix=AVX2 %s // Check that all math functions lowered to approximations built from // standard operations (add, mul, fma, shift, etc...). @@ -300,3 +302,37 @@ func @tanh_vector(%arg0: vector<8xf32>) -> vector<8xf32> { %0 = math.tanh %arg0 : vector<8xf32> return %0 : vector<8xf32> } + +// We only approximate rsqrt for vectors and when the AVX2 option is enabled. +// CHECK-LABEL: func @rsqrt_scalar +// AVX2-LABEL: func @rsqrt_scalar +// CHECK: math.rsqrt +// AVX2: math.rsqrt +func @rsqrt_scalar(%arg0: f32) -> f32 { + %0 = math.rsqrt %arg0 : f32 + return %0 : f32 +} + +// CHECK-LABEL: func @rsqrt_vector +// CHECK: math.rsqrt +// AVX2-LABEL: func @rsqrt_vector( +// AVX2-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> { +// AVX2: %[[VAL_1:.*]] = arith.constant dense<0x7F800000> : vector<8xf32> +// AVX2: %[[VAL_2:.*]] = arith.constant dense<1.500000e+00> : vector<8xf32> +// AVX2: %[[VAL_3:.*]] = arith.constant dense<-5.000000e-01> : vector<8xf32> +// AVX2: %[[VAL_4:.*]] = arith.constant dense<1.17549435E-38> : vector<8xf32> +// AVX2: %[[VAL_5:.*]] = arith.mulf %[[VAL_0]], %[[VAL_3]] : vector<8xf32> +// AVX2: %[[VAL_6:.*]] = arith.cmpf olt, %[[VAL_0]], %[[VAL_4]] : vector<8xf32> +// AVX2: %[[VAL_7:.*]] = arith.cmpf oeq, %[[VAL_0]], %[[VAL_1]] : vector<8xf32> +// AVX2: %[[VAL_8:.*]] = arith.ori %[[VAL_6]], %[[VAL_7]] : vector<8xi1> +// AVX2: %[[VAL_9:.*]] = x86vector.avx.rsqrt %[[VAL_0]] : vector<8xf32> +// AVX2: %[[VAL_10:.*]] = arith.mulf %[[VAL_5]], %[[VAL_9]] : vector<8xf32> +// AVX2: %[[VAL_11:.*]] = math.fma %[[VAL_9]], %[[VAL_10]], %[[VAL_2]] : vector<8xf32> +// AVX2: %[[VAL_12:.*]] = arith.mulf %[[VAL_9]], %[[VAL_11]] : vector<8xf32> +// AVX2: %[[VAL_13:.*]] = select %[[VAL_8]], %[[VAL_9]], %[[VAL_12]] : vector<8xi1>, vector<8xf32> +// AVX2: return %[[VAL_13]] : vector<8xf32> +// AVX2: } +func @rsqrt_vector(%arg0: vector<8xf32>) -> vector<8xf32> { + %0 = math.rsqrt %arg0 : vector<8xf32> + return %0 : vector<8xf32> +} diff --git a/mlir/test/lib/Dialect/Math/CMakeLists.txt b/mlir/test/lib/Dialect/Math/CMakeLists.txt index 64cae2f77c5e3..dd2f726928ae6 100644 --- a/mlir/test/lib/Dialect/Math/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Math/CMakeLists.txt @@ -11,4 +11,5 @@ add_mlir_library(MLIRMathTestPasses MLIRPass MLIRTransformUtils MLIRVector + MLIRX86Vector ) diff --git a/mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp b/mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp index c035eab7ee527..d1aa7796437f7 100644 --- a/mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp +++ b/mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -23,10 +24,16 @@ using namespace mlir; namespace { struct TestMathPolynomialApproximationPass : public PassWrapper { + TestMathPolynomialApproximationPass() = default; + TestMathPolynomialApproximationPass( + const TestMathPolynomialApproximationPass &pass) {} + void runOnFunction() override; void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); + if (enableAvx2) + registry.insert(); } StringRef getArgument() const final { return "test-math-polynomial-approximation"; @@ -34,12 +41,20 @@ struct TestMathPolynomialApproximationPass StringRef getDescription() const final { return "Test math polynomial approximations"; } + + Option enableAvx2{ + *this, "enable-avx2", + llvm::cl::desc("Enable approximations that emit AVX2 intrinsics via the " + "X86Vector dialect"), + llvm::cl::init(false)}; }; } // end anonymous namespace void TestMathPolynomialApproximationPass::runOnFunction() { RewritePatternSet patterns(&getContext()); - populateMathPolynomialApproximationPatterns(patterns); + MathPolynomialApproximationOptions approx_options; + approx_options.enableAvx2 = enableAvx2; + populateMathPolynomialApproximationPatterns(patterns, approx_options); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } diff --git a/mlir/test/mlir-cpu-runner/X86Vector/lit.local.cfg b/mlir/test/mlir-cpu-runner/X86Vector/lit.local.cfg new file mode 100644 index 0000000000000..88d1d75f7df8b --- /dev/null +++ b/mlir/test/mlir-cpu-runner/X86Vector/lit.local.cfg @@ -0,0 +1,5 @@ +import sys + +# X86Vector tests must be enabled via build flag. +if not config.mlir_run_x86vector_tests: + config.unsupported = True diff --git a/mlir/test/mlir-cpu-runner/X86Vector/math_polynomial_approx_avx2.mlir b/mlir/test/mlir-cpu-runner/X86Vector/math_polynomial_approx_avx2.mlir new file mode 100644 index 0000000000000..ce7a6550dbecf --- /dev/null +++ b/mlir/test/mlir-cpu-runner/X86Vector/math_polynomial_approx_avx2.mlir @@ -0,0 +1,40 @@ +// RUN: mlir-opt %s -test-math-polynomial-approximation="enable-avx2" \ +// RUN: -convert-arith-to-llvm \ +// RUN: -convert-vector-to-llvm="enable-x86vector" \ +// RUN: -convert-math-to-llvm \ +// RUN: -convert-std-to-llvm \ +// RUN: -reconcile-unrealized-casts \ +// RUN: | mlir-cpu-runner \ +// RUN: -e main -entry-point-result=void -O0 \ +// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \ +// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \ +// RUN: | FileCheck %s + +// -------------------------------------------------------------------------- // +// rsqrt. +// -------------------------------------------------------------------------- // + +func @rsqrt() { + // Sanity-check that the scalar rsqrt still works OK. + // CHECK: inf + %0 = arith.constant 0.0 : f32 + %rsqrt_0 = math.rsqrt %0 : f32 + vector.print %rsqrt_0 : f32 + // CHECK: 0.707107 + %two = arith.constant 2.0: f32 + %rsqrt_two = math.rsqrt %two : f32 + vector.print %rsqrt_two : f32 + + // Check that the vectorized approximation is reasonably accurate. + // CHECK: 0.707107, 0.707107, 0.707107, 0.707107, 0.707107, 0.707107, 0.707107, 0.707107 + %vec8 = arith.constant dense<2.0> : vector<8xf32> + %rsqrt_vec8 = math.rsqrt %vec8 : vector<8xf32> + vector.print %rsqrt_vec8 : vector<8xf32> + + return +} + +func @main() { + call @rsqrt(): () -> () + return +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 70c4130c716bc..fea353d460cb9 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -7057,6 +7057,7 @@ cc_library( ":Support", ":Transforms", ":VectorOps", + ":X86Vector", "//llvm:Support", ], ) diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel index 8e831e142fcc2..f776696ade5ba 100644 --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -406,6 +406,7 @@ cc_library( "//mlir:Pass", "//mlir:TransformUtils", "//mlir:VectorOps", + "//mlir:X86Vector", ], )