forked from llvm/llvm-project
[mlir] Add polynomial approximation for vectorized math::Rsqrt
This patch adds a polynomial approximation that matches the approximation in Eigen. Note that the approximation only applies to vectorized inputs; the scalar rsqrt is left unmodified.

The approximation is protected with a flag since it emits an AVX2 intrinsic (generated via the X86Vector dialect). This is the only reasonably clean way that I could find to generate the exact approximation that I wanted (i.e. one identical to Eigen's). I considered two alternatives:

1. Introduce a Rsqrt intrinsic in LLVM, which doesn't exist yet. I believe this is because there is no definition of Rsqrt that all backends could agree on, since the hardware instructions that implement it have widely varying degrees of precision. This is something that the standard could mandate, but Rsqrt is not part of IEEE 754, so I don't think this option is feasible.

2. Emit fdiv(1.0, sqrt) with fast math flags to allow reciprocal transformations. Although portable, this doesn't allow us to generate exactly the code we want; it is the LLVM backend, and not MLIR, that controls what code is generated based on the target CPU.

Reviewed By: ezhulenev

Differential Revision: https://reviews.llvm.org/D112192
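The commit message does not spell out the steps of the approximation. As a rough illustration only, the sketch below assumes the Eigen-style scheme is a low-precision hardware estimate refined by a single Newton-Raphson step, with the raw estimate kept for zero and infinite inputs (where the Newton step would produce NaN). The helper names `rsqrt_estimate` and `rsqrt_refined`, and the simulated relative error of 2^-12, are hypothetical and are not taken from the patch; the code is scalar Python, not the vectorized MLIR lowering, and it assumes non-negative inputs.

```python
import math


def rsqrt_estimate(x: float, rel_err: float = 2.0 ** -12) -> float:
    """Hypothetical stand-in for a low-precision hardware rsqrt estimate.

    A real AVX rsqrt instruction returns an approximation with limited
    precision; here that is simulated by perturbing the exact value.
    Assumes x >= 0.
    """
    if x == 0.0:
        return math.inf
    if math.isinf(x):
        return 0.0
    return (1.0 / math.sqrt(x)) * (1.0 + rel_err)


def rsqrt_refined(x: float) -> float:
    """Refine the estimate with one Newton-Raphson step (assumed scheme)."""
    y = rsqrt_estimate(x)
    # For x == 0 or x == inf the Newton step would evaluate 0 * inf = NaN,
    # so the estimate is returned unchanged in those cases.
    if x == 0.0 or math.isinf(x):
        return y
    # Newton-Raphson iteration for f(y) = 1/y^2 - x.
    return y * (1.5 - 0.5 * x * y * y)


if __name__ == "__main__":
    exact = 1.0 / math.sqrt(2.0)
    print("estimate error:", abs(rsqrt_estimate(2.0) - exact))
    print("refined error: ", abs(rsqrt_refined(2.0) - exact))
```

Under these assumptions, one Newton step roughly squares the relative error of the estimate, which is why a single step can be enough to reach approximately single-precision accuracy from a 12-bit estimate.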
Showing 10 changed files with 177 additions and 3 deletions.
@@ -13,4 +13,5 @@ add_mlir_dialect_library(MLIRMathTransforms
  MLIRPass
  MLIRStandard
  MLIRTransforms
  MLIRX86Vector
  )
@@ -11,4 +11,5 @@ add_mlir_library(MLIRMathTestPasses
  MLIRPass
  MLIRTransformUtils
  MLIRVector
  MLIRX86Vector
  )
@@ -0,0 +1,5 @@
import sys

# X86Vector tests must be enabled via build flag.
if not config.mlir_run_x86vector_tests:
    config.unsupported = True
40 changes: 40 additions & 0 deletions
mlir/test/mlir-cpu-runner/X86Vector/math_polynomial_approx_avx2.mlir
@@ -0,0 +1,40 @@
// RUN: mlir-opt %s -test-math-polynomial-approximation="enable-avx2" \
// RUN: -convert-arith-to-llvm \
// RUN: -convert-vector-to-llvm="enable-x86vector" \
// RUN: -convert-math-to-llvm \
// RUN: -convert-std-to-llvm \
// RUN: -reconcile-unrealized-casts \
// RUN: | mlir-cpu-runner \
// RUN: -e main -entry-point-result=void -O0 \
// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext \
// RUN: -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \
// RUN: | FileCheck %s

// -------------------------------------------------------------------------- //
// rsqrt.
// -------------------------------------------------------------------------- //

func @rsqrt() {
  // Sanity-check that the scalar rsqrt still works OK.
  // CHECK: inf
  %0 = arith.constant 0.0 : f32
  %rsqrt_0 = math.rsqrt %0 : f32
  vector.print %rsqrt_0 : f32
  // CHECK: 0.707107
  %two = arith.constant 2.0: f32
  %rsqrt_two = math.rsqrt %two : f32
  vector.print %rsqrt_two : f32

  // Check that the vectorized approximation is reasonably accurate.
  // CHECK: 0.707107, 0.707107, 0.707107, 0.707107, 0.707107, 0.707107, 0.707107, 0.707107
  %vec8 = arith.constant dense<2.0> : vector<8xf32>
  %rsqrt_vec8 = math.rsqrt %vec8 : vector<8xf32>
  vector.print %rsqrt_vec8 : vector<8xf32>

  return
}

func @main() {
  call @rsqrt(): () -> ()
  return
}
@@ -7057,6 +7057,7 @@ cc_library(
        ":Support",
        ":Transforms",
        ":VectorOps",
        ":X86Vector",
        "//llvm:Support",
    ],
)
@@ -406,6 +406,7 @@ cc_library(
        "//mlir:Pass",
        "//mlir:TransformUtils",
        "//mlir:VectorOps",
        "//mlir:X86Vector",
    ],
)