Skip to content

Commit

Permalink
[mlir] Add polynomial approximation for vectorized math::Rsqrt
Browse files Browse the repository at this point in the history
This patch adds a polynomial approximation that matches the
approximation in Eigen.

Note that the approximation only applies to vectorized inputs;
the scalar rsqrt is left unmodified.

The approximation is guarded by a flag since it emits an AVX2
intrinsic (generated via the X86Vector dialect). This is the only
reasonably clean way that I could find to generate the exact approximation
that I wanted (i.e. one identical to Eigen's).

I considered two alternatives:

1. Introduce a Rsqrt intrinsic in LLVM, which doesn't exist yet.
   I believe this is because there is no definition of Rsqrt that
   all backends could agree on, since hardware instructions that
   implement it have widely varying degrees of precision.
   This is something that the standard could mandate, but Rsqrt is
   not part of IEEE754, so I don't think this option is feasible.

2. Emit fdiv(1.0, sqrt) with fast math flags to allow reciprocal
   transformations. Although portable, this doesn't allow us
   to generate exactly the code we want; it is the LLVM backend,
   and not MLIR, that controls what code is generated based on the
   target CPU.

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D112192
  • Loading branch information
cota authored and ezhulenev committed Oct 23, 2021
1 parent c534835 commit 35553d4
Show file tree
Hide file tree
Showing 10 changed files with 177 additions and 3 deletions.
9 changes: 8 additions & 1 deletion mlir/include/mlir/Dialect/Math/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@ void populateExpandTanhPattern(RewritePatternSet &patterns);

void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);

void populateMathPolynomialApproximationPatterns(RewritePatternSet &patterns);
/// Options consumed when populating the math polynomial approximation
/// patterns.
struct MathPolynomialApproximationOptions {
  /// Allows some of the approximations to use AVX2 intrinsics. Disabled
  /// unless explicitly turned on.
  bool enableAvx2{false};
};

/// Adds the math polynomial approximation rewrite patterns to `patterns`.
///
/// `options` selects optional approximations; when the argument is omitted a
/// value-initialized MathPolynomialApproximationOptions is used.
void populateMathPolynomialApproximationPatterns(
RewritePatternSet &patterns,
const MathPolynomialApproximationOptions &options = {});

} // namespace mlir

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Math/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ add_mlir_dialect_library(MLIRMathTransforms
MLIRPass
MLIRStandard
MLIRTransforms
MLIRX86Vector
)
69 changes: 68 additions & 1 deletion mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Transforms/Bufferize.h"
Expand Down Expand Up @@ -778,13 +779,79 @@ LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
return success();
}

//----------------------------------------------------------------------------//
// Rsqrt approximation.
//----------------------------------------------------------------------------//

namespace {
/// Rewrite pattern rooted at math::RsqrtOp. The rewrite logic lives in the
/// out-of-line definition of matchAndRewrite.
struct RsqrtApproximation : public OpRewritePattern<math::RsqrtOp> {
// Reuse the constructors of OpRewritePattern.
using OpRewritePattern::OpRewritePattern;

// Returns success() after replacing `op`, or a match failure when the operand
// type is not supported.
LogicalResult matchAndRewrite(math::RsqrtOp op,
PatternRewriter &rewriter) const final;
};
} // namespace

// Replaces a vectorized math.rsqrt with an x86vector rsqrt estimate that is
// refined by a single Newton-Raphson step.
//
// The rewrite only fires when vectorWidth(operand type, isF32) yields a value
// equal to 8; otherwise a match failure is reported and `op` is left alone.
// On success `op` is replaced by a select that picks, per lane, either the raw
// estimate or the refined value, and success() is returned.
//
// NOTE: the order in which operations are created below is observable in the
// generated IR, so statements should not be reordered casually.
LogicalResult
RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
PatternRewriter &rewriter) const {
auto width = vectorWidth(op.operand().getType(), isF32);
// Only support already-vectorized rsqrt's (and only those of width 8).
if (!width.hasValue() || *width != 8)
return rewriter.notifyMatchFailure(op, "unsupported operand type");

// All ops below are created at the location of the original op.
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
// Helper that broadcasts a value to the matched vector width.
auto bcast = [&](Value value) -> Value {
return broadcast(builder, value, *width);
};

// 0x7f800000 is the f32 bit pattern of +inf; 0x00800000 is the bit pattern
// of the smallest positive normal f32.
Value cstPosInf = bcast(f32FromBits(builder, 0x7f800000u));
Value cstOnePointFive = bcast(f32Cst(builder, 1.5f));
Value cstNegHalf = bcast(f32Cst(builder, -0.5f));
Value cstMinNormPos = bcast(f32FromBits(builder, 0x00800000u));

// negHalf = -0.5 * x; the sign is folded in here so that the Newton-Raphson
// step below can be expressed with a single fma.
Value negHalf = builder.create<arith::MulFOp>(op.operand(), cstNegHalf);

// Select only the inverse sqrt of positive normals (denormals are
// flushed to zero). The combined mask is set for lanes where x is below the
// smallest positive normal or equal to +inf; both comparisons are ordered
// (OLT / OEQ).
Value ltMinMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT,
op.operand(), cstMinNormPos);
Value infMask = builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ,
op.operand(), cstPosInf);
Value notNormalFiniteMask = builder.create<arith::OrIOp>(ltMinMask, infMask);

// Compute an approximate result.
Value yApprox = builder.create<x86vector::RsqrtOp>(op.operand());

// Do a single step of Newton-Raphson iteration to improve the approximation.
// This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
// It is essential to evaluate the inner term like this because forming
// y_n^2 may over- or underflow.
// Concretely: inner = (-0.5 * x) * y_n, fma = y_n * inner + 1.5, and the
// refined value is y_n * fma.
Value inner = builder.create<arith::MulFOp>(negHalf, yApprox);
Value fma = builder.create<math::FmaOp>(yApprox, inner, cstOnePointFive);
Value yNewton = builder.create<arith::MulFOp>(yApprox, fma);

// Select the result of the Newton-Raphson step for positive normal arguments.
// For other arguments, choose the output of the intrinsic. This will
// return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if
// x is zero or a positive denormalized float (equivalent to flushing positive
// denormalized inputs to zero).
Value res = builder.create<SelectOp>(notNormalFiniteMask, yApprox, yNewton);
rewriter.replaceOp(op, res);

return success();
}

//----------------------------------------------------------------------------//

// Registers the polynomial approximation patterns into `patterns`.
//
// The tanh, log, log2, log1p, exp, expm1, sin and cos approximations are
// always added. RsqrtApproximation is added only when `options.enableAvx2`
// is set.
void mlir::populateMathPolynomialApproximationPatterns(
// NOTE(review): the next line looks like the pre-change parameter list kept
// by the diff view this text was taken from — confirm against the final
// source.
RewritePatternSet &patterns) {
RewritePatternSet &patterns,
const MathPolynomialApproximationOptions &options) {
// Approximations that are registered unconditionally.
patterns.add<TanhApproximation, LogApproximation, Log2Approximation,
Log1pApproximation, ExpApproximation, ExpM1Approximation,
SinAndCosApproximation<true, math::SinOp>,
SinAndCosApproximation<false, math::CosOp>>(
patterns.getContext());
// The rsqrt approximation is opt-in via the AVX2 option.
if (options.enableAvx2)
patterns.add<RsqrtApproximation>(patterns.getContext());
}
36 changes: 36 additions & 0 deletions mlir/test/Dialect/Math/polynomial-approximation.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// RUN: mlir-opt %s -test-math-polynomial-approximation | FileCheck %s
// RUN: mlir-opt %s -test-math-polynomial-approximation=enable-avx2 \
// RUN: | FileCheck --check-prefix=AVX2 %s

// Check that all math functions lowered to approximations built from
// standard operations (add, mul, fma, shift, etc...).
Expand Down Expand Up @@ -300,3 +302,37 @@ func @tanh_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
%0 = math.tanh %arg0 : vector<8xf32>
return %0 : vector<8xf32>
}

// We only approximate rsqrt for vectors and when the AVX2 option is enabled.
// CHECK-LABEL: func @rsqrt_scalar
// AVX2-LABEL: func @rsqrt_scalar
// CHECK: math.rsqrt
// AVX2: math.rsqrt
func @rsqrt_scalar(%arg0: f32) -> f32 {
%0 = math.rsqrt %arg0 : f32
return %0 : f32
}

// CHECK-LABEL: func @rsqrt_vector
// CHECK: math.rsqrt
// AVX2-LABEL: func @rsqrt_vector(
// AVX2-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> {
// AVX2: %[[VAL_1:.*]] = arith.constant dense<0x7F800000> : vector<8xf32>
// AVX2: %[[VAL_2:.*]] = arith.constant dense<1.500000e+00> : vector<8xf32>
// AVX2: %[[VAL_3:.*]] = arith.constant dense<-5.000000e-01> : vector<8xf32>
// AVX2: %[[VAL_4:.*]] = arith.constant dense<1.17549435E-38> : vector<8xf32>
// AVX2: %[[VAL_5:.*]] = arith.mulf %[[VAL_0]], %[[VAL_3]] : vector<8xf32>
// AVX2: %[[VAL_6:.*]] = arith.cmpf olt, %[[VAL_0]], %[[VAL_4]] : vector<8xf32>
// AVX2: %[[VAL_7:.*]] = arith.cmpf oeq, %[[VAL_0]], %[[VAL_1]] : vector<8xf32>
// AVX2: %[[VAL_8:.*]] = arith.ori %[[VAL_6]], %[[VAL_7]] : vector<8xi1>
// AVX2: %[[VAL_9:.*]] = x86vector.avx.rsqrt %[[VAL_0]] : vector<8xf32>
// AVX2: %[[VAL_10:.*]] = arith.mulf %[[VAL_5]], %[[VAL_9]] : vector<8xf32>
// AVX2: %[[VAL_11:.*]] = math.fma %[[VAL_9]], %[[VAL_10]], %[[VAL_2]] : vector<8xf32>
// AVX2: %[[VAL_12:.*]] = arith.mulf %[[VAL_9]], %[[VAL_11]] : vector<8xf32>
// AVX2: %[[VAL_13:.*]] = select %[[VAL_8]], %[[VAL_9]], %[[VAL_12]] : vector<8xi1>, vector<8xf32>
// AVX2: return %[[VAL_13]] : vector<8xf32>
// AVX2: }
func @rsqrt_vector(%arg0: vector<8xf32>) -> vector<8xf32> {
%0 = math.rsqrt %arg0 : vector<8xf32>
return %0 : vector<8xf32>
}
1 change: 1 addition & 0 deletions mlir/test/lib/Dialect/Math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ add_mlir_library(MLIRMathTestPasses
MLIRPass
MLIRTransformUtils
MLIRVector
MLIRX86Vector
)
17 changes: 16 additions & 1 deletion mlir/test/lib/Dialect/Math/TestPolynomialApproximation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

Expand All @@ -23,23 +24,37 @@ using namespace mlir;
namespace {
// Test-only function pass that applies the math polynomial approximation
// patterns. It is exposed under the argument
// "test-math-polynomial-approximation" and has an "enable-avx2" option.
struct TestMathPolynomialApproximationPass
: public PassWrapper<TestMathPolynomialApproximationPass, FunctionPass> {
TestMathPolynomialApproximationPass() = default;
// User-provided copy constructor with an empty body; it copies no state from
// `pass` itself.
// NOTE(review): presumably needed because the Option member below prevents
// an implicit copy constructor while the pass infrastructure must be able to
// clone the pass — confirm.
TestMathPolynomialApproximationPass(
const TestMathPolynomialApproximationPass &pass) {}

void runOnFunction() override;
// Declares the dialects this pass may create ops from: arith, math and
// vector always, and x86vector only when the enable-avx2 option is set.
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithmeticDialect, math::MathDialect,
vector::VectorDialect>();
if (enableAvx2)
registry.insert<x86vector::X86VectorDialect>();
}
// Command-line argument used to select this pass.
StringRef getArgument() const final {
return "test-math-polynomial-approximation";
}
// Human-readable description of the pass.
StringRef getDescription() const final {
return "Test math polynomial approximations";
}

// Pass option "enable-avx2"; defaults to false.
Option<bool> enableAvx2{
*this, "enable-avx2",
llvm::cl::desc("Enable approximations that emit AVX2 intrinsics via the "
"X86Vector dialect"),
llvm::cl::init(false)};
};
} // end anonymous namespace

// Populates the approximation patterns, forwarding the pass's enable-avx2
// option, and applies them greedily to the current function.
void TestMathPolynomialApproximationPass::runOnFunction() {
RewritePatternSet patterns(&getContext());
// NOTE(review): the next call (without options) looks like the pre-change
// line kept by the diff view this text was taken from — confirm against the
// final source.
populateMathPolynomialApproximationPatterns(patterns);
// Forward the command-line option into the pattern-population options.
MathPolynomialApproximationOptions approx_options;
approx_options.enableAvx2 = enableAvx2;
populateMathPolynomialApproximationPatterns(patterns, approx_options);
// The result of the greedy driver is deliberately ignored.
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

Expand Down
5 changes: 5 additions & 0 deletions mlir/test/mlir-cpu-runner/X86Vector/lit.local.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import sys

# X86Vector tests must be enabled via build flag.
# `config` is not defined in this file; presumably lit injects it when loading
# the configuration — confirm. When the flag is off, the tests governed by
# this config are marked unsupported instead of being run.
if not config.mlir_run_x86vector_tests:
    config.unsupported = True
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// RUN: mlir-opt %s -test-math-polynomial-approximation="enable-avx2" \
// RUN: -convert-arith-to-llvm \
// RUN: -convert-vector-to-llvm="enable-x86vector" \
// RUN: -convert-math-to-llvm \
// RUN: -convert-std-to-llvm \
// RUN: -reconcile-unrealized-casts \
// RUN: | mlir-cpu-runner \
// RUN: -e main -entry-point-result=void -O0 \
// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \
// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s

// -------------------------------------------------------------------------- //
// rsqrt.
// -------------------------------------------------------------------------- //

func @rsqrt() {
// Sanity-check that the scalar rsqrt still works OK.
// CHECK: inf
%0 = arith.constant 0.0 : f32
%rsqrt_0 = math.rsqrt %0 : f32
vector.print %rsqrt_0 : f32
// CHECK: 0.707107
%two = arith.constant 2.0: f32
%rsqrt_two = math.rsqrt %two : f32
vector.print %rsqrt_two : f32

// Check that the vectorized approximation is reasonably accurate.
// CHECK: 0.707107, 0.707107, 0.707107, 0.707107, 0.707107, 0.707107, 0.707107, 0.707107
%vec8 = arith.constant dense<2.0> : vector<8xf32>
%rsqrt_vec8 = math.rsqrt %vec8 : vector<8xf32>
vector.print %rsqrt_vec8 : vector<8xf32>

return
}

func @main() {
call @rsqrt(): () -> ()
return
}
1 change: 1 addition & 0 deletions utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -7057,6 +7057,7 @@ cc_library(
":Support",
":Transforms",
":VectorOps",
":X86Vector",
"//llvm:Support",
],
)
Expand Down
1 change: 1 addition & 0 deletions utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,7 @@ cc_library(
"//mlir:Pass",
"//mlir:TransformUtils",
"//mlir:VectorOps",
"//mlir:X86Vector",
],
)

Expand Down

0 comments on commit 35553d4

Please sign in to comment.