Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fast atan and atan2 functions. #8388

Open
wants to merge 35 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
69052fa
Fix for the removed DataLayout constructor.
mcourteaux Aug 13, 2024
e82d9ff
Fast vectorizable atan and atan2 functions.
mcourteaux Aug 10, 2024
11b442c
Default to not using fast atan versions if on CUDA.
mcourteaux Aug 10, 2024
dee28bc
Finished fast atan/atan2 functions and tests.
mcourteaux Aug 10, 2024
362f0ea
Correct attribution.
mcourteaux Aug 10, 2024
1bd7f7a
Clang-format
mcourteaux Aug 10, 2024
4f1e851
Weird WebAssembly limits...
mcourteaux Aug 11, 2024
f10396b
Small improvements to the optimization script.
mcourteaux Aug 11, 2024
de9d3b7
Polynomial optimization for log, exp, sin, cos with correct ranges.
mcourteaux Aug 11, 2024
d8e3225
Improve fast atan performance tests for GPU.
mcourteaux Aug 12, 2024
3bcd1a7
Bugfix fast_atan approximation. Fix correctness test to exceed the ra…
mcourteaux Aug 12, 2024
2aa0c7e
Cleanup
mcourteaux Aug 12, 2024
fd088f8
Enum class instead of enum for ApproximationPrecision.
mcourteaux Aug 12, 2024
62534d7
Weird Metal limits. There should be a better way...
mcourteaux Aug 12, 2024
c76e719
Skip test for WebGPU.
mcourteaux Aug 12, 2024
fc25944
Fast atan/atan2 polynomials reoptimized. New optimization strategy: ULP.
mcourteaux Aug 13, 2024
b5d0cad
Feedback Steven.
mcourteaux Aug 13, 2024
4d61c6a
More comments and test mantissa error.
mcourteaux Aug 14, 2024
ff28b99
Do not error when testing arctan performance on Metal / WebGPU.
mcourteaux Aug 14, 2024
5a435f0
Partially apply clang-tidy fixes we don't enforce yet (#8376)
abadams Aug 16, 2024
a4544be
Fix bundling error on buildbots (#8392)
alexreinking Aug 16, 2024
624f737
Fix incorrect std::array sizes in Target.cpp (#8396)
steven-johnson Aug 23, 2024
5ca88b7
Fix _Float16 detection on ARM64 GCC<13 (#8401)
alexreinking Aug 29, 2024
238f73c
Update README.md (#8404)
abadams Sep 2, 2024
b09f611
Support CMAKE_OSX_ARCHITECTURES (#8390)
alexreinking Sep 4, 2024
0614530
Pip packaging at last! (#8405)
alexreinking Sep 4, 2024
ae6dac4
Big documentation update (#8410)
alexreinking Sep 5, 2024
30b5938
Document how to find Halide from a pip installation (#8411)
alexreinking Sep 6, 2024
6f0da12
Merge pull request #8412
alexreinking Sep 6, 2024
44651f9
Fix classifier spelling (#8413)
alexreinking Sep 7, 2024
636ad8f
Make run-clang-tidy.sh work on macOS (#8416)
alexreinking Sep 9, 2024
51824df
Link to PyPI from Doxygen index.html (#8415)
alexreinking Sep 9, 2024
c9b2a76
Include our Markdown documentation in the Doxygen site. (#8417)
alexreinking Sep 10, 2024
a8966e9
Add missing backslash (#8419)
abadams Sep 15, 2024
9bcb9b7
Reschedule the matrix multiply performance app (#8418)
abadams Sep 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
More comments and test mantissa error.
  • Loading branch information
mcourteaux committed Aug 14, 2024
commit 4d61c6a1398fb321f55aae498c21e5ca68ab9f8c
68 changes: 47 additions & 21 deletions src/IROperator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1447,69 +1447,95 @@ Expr fast_atan_approximation(const Expr &x_full, ApproximationPrecision precisio
switch (precision) {
mcourteaux marked this conversation as resolved.
Show resolved Hide resolved
// == MSE Optimized == //
case ApproximationPrecision::MSE_Poly2: // (MSE=1.0264e-05, MAE=9.2149e-03, MaxUlpE=3.9855e+05)
c = {+9.762134539879e-01f, -2.000301999499e-01f}; break;
c = {+9.762134539879e-01f, -2.000301999499e-01f};
break;
case ApproximationPrecision::MSE_Poly3: // (MSE=1.5776e-07, MAE=1.3239e-03, MaxUlpE=6.7246e+04)
c = {+9.959820734941e-01f, -2.922781275652e-01f, +8.301806798764e-02f}; break;
c = {+9.959820734941e-01f, -2.922781275652e-01f, +8.301806798764e-02f};
break;
case ApproximationPrecision::MSE_Poly4: // (MSE=2.8490e-09, MAE=1.9922e-04, MaxUlpE=1.1422e+04)
c = {+9.993165406918e-01f, -3.222865011143e-01f, +1.490324612527e-01f, -4.086355921512e-02f}; break;
c = {+9.993165406918e-01f, -3.222865011143e-01f, +1.490324612527e-01f, -4.086355921512e-02f};
break;
case ApproximationPrecision::MSE_Poly5: // (MSE=5.6675e-11, MAE=3.0801e-05, MaxUlpE=1.9456e+03)
c = {+9.998833730470e-01f, -3.305995351168e-01f, +1.814513158372e-01f, -8.717338298570e-02f,
+2.186719361787e-02f}; break;
+2.186719361787e-02f};
break;
case ApproximationPrecision::MSE_Poly6: // (MSE=1.2027e-12, MAE=4.8469e-06, MaxUlpE=3.3187e+02)
c = {+9.999800646964e-01f, -3.326943930673e-01f, +1.940196968486e-01f, -1.176947321238e-01f,
+5.408220801540e-02f, -1.229952788751e-02f}; break;
+5.408220801540e-02f, -1.229952788751e-02f};
break;
case ApproximationPrecision::MSE_Poly7: // (MSE=2.6729e-14, MAE=7.7227e-07, MaxUlpE=5.6646e+01)
c = {+9.999965889517e-01f, -3.331900904961e-01f, +1.982328680483e-01f, -1.329414694644e-01f,
+8.076237117606e-02f, -3.461248530394e-02f, +7.151152759080e-03f}; break;
+8.076237117606e-02f, -3.461248530394e-02f, +7.151152759080e-03f};
break;
case ApproximationPrecision::MSE_Poly8: // (MSE=6.1506e-16, MAE=1.2419e-07, MaxUlpE=9.6914e+00)
c = {+9.999994159669e-01f, -3.333022219271e-01f, +1.995110884308e-01f, -1.393321817395e-01f,
+9.709319573480e-02f, -5.688043380309e-02f, +2.256648487698e-02f, -4.257308331872e-03f}; break;
+9.709319573480e-02f, -5.688043380309e-02f, +2.256648487698e-02f, -4.257308331872e-03f};
break;

// == MAE Optimized == //
case ApproximationPrecision::MAE_1e_2:
case ApproximationPrecision::MAE_Poly2: // (MSE=1.2096e-05, MAE=4.9690e-03, MaxUlpE=4.6233e+05)
c = {+9.724104536788e-01f, -1.919812827495e-01f}; break;
c = {+9.724104536788e-01f, -1.919812827495e-01f};
break;
case ApproximationPrecision::MAE_1e_3:
case ApproximationPrecision::MAE_Poly3: // (MSE=1.8394e-07, MAE=6.1071e-04, MaxUlpE=7.7667e+04)
c = {+9.953600796593e-01f, -2.887020515559e-01f, +7.935084373856e-02f}; break;
c = {+9.953600796593e-01f, -2.887020515559e-01f, +7.935084373856e-02f};
break;
case ApproximationPrecision::MAE_1e_4:
case ApproximationPrecision::MAE_Poly4: // (MSE=3.2969e-09, MAE=8.1642e-05, MaxUlpE=1.3136e+04)
c = {+9.992141075707e-01f, -3.211780734117e-01f, +1.462720063085e-01f, -3.899151874271e-02f}; break;
c = {+9.992141075707e-01f, -3.211780734117e-01f, +1.462720063085e-01f, -3.899151874271e-02f};
break;
case ApproximationPrecision::MAE_Poly5: // (MSE=6.5235e-11, MAE=1.1475e-05, MaxUlpE=2.2296e+03)
c = {+9.998663727249e-01f, -3.303055171903e-01f, +1.801624340886e-01f, -8.516115366058e-02f,
+2.084750202717e-02f}; break;
+2.084750202717e-02f};
break;
case ApproximationPrecision::MAE_1e_5:
case ApproximationPrecision::MAE_Poly6: // (MSE=1.3788e-12, MAE=1.6673e-06, MaxUlpE=3.7921e+02)
c = {+9.999772256973e-01f, -3.326229914097e-01f, +1.935414518077e-01f, -1.164292778405e-01f,
+5.265046001895e-02f, -1.172037220425e-02f}; break;
+5.265046001895e-02f, -1.172037220425e-02f};
break;
case ApproximationPrecision::MAE_1e_6:
case ApproximationPrecision::MAE_Poly7: // (MSE=3.0551e-14, MAE=2.4809e-07, MaxUlpE=6.4572e+01)
c = {+9.999961125922e-01f, -3.331737159104e-01f, +1.980784841430e-01f, -1.323346922675e-01f,
+7.962601662878e-02f, -3.360626486524e-02f, +6.812471171209e-03f}; break;
+7.962601662878e-02f, -3.360626486524e-02f, +6.812471171209e-03f};
break;
case ApproximationPrecision::MAE_Poly8: // (MSE=7.0132e-16, MAE=3.7579e-08, MaxUlpE=1.1023e+01)
c = {+9.999993357462e-01f, -3.332986153129e-01f, +1.994657492754e-01f, -1.390867909988e-01f,
+9.642330770840e-02f, -5.591422536378e-02f, +2.186431903729e-02f, -4.054954273090e-03f}; break;
+9.642330770840e-02f, -5.591422536378e-02f, +2.186431903729e-02f, -4.054954273090e-03f};
break;


// == Max ULP Optimized == //
case ApproximationPrecision::MULPE_Poly2: // (MSE=2.1006e-05, MAE=1.0755e-02, MaxUlpE=1.8221e+05)
c = {+9.891111216318e-01f, -2.144680385336e-01f}; break;
c = {+9.891111216318e-01f, -2.144680385336e-01f};
break;
case ApproximationPrecision::MULPE_1e_2:
case ApproximationPrecision::MULPE_Poly3: // (MSE=3.5740e-07, MAE=1.3164e-03, MaxUlpE=2.2273e+04)
c = {+9.986650768126e-01f, -3.029909865833e-01f, +9.104044335898e-02f}; break;
c = {+9.986650768126e-01f, -3.029909865833e-01f, +9.104044335898e-02f};
break;
case ApproximationPrecision::MULPE_1e_3:
case ApproximationPrecision::MULPE_Poly4: // (MSE=6.4750e-09, MAE=1.5485e-04, MaxUlpE=2.6199e+03)
c = {+9.998421981586e-01f, -3.262726405770e-01f, +1.562944595469e-01f, -4.462070448745e-02f}; break;
c = {+9.998421981586e-01f, -3.262726405770e-01f, +1.562944595469e-01f, -4.462070448745e-02f};
break;
case ApproximationPrecision::MULPE_1e_4:
case ApproximationPrecision::MULPE_Poly5: // (MSE=1.3135e-10, MAE=2.5335e-05, MaxUlpE=4.2948e+02)
c = {+9.999741103798e-01f, -3.318237821017e-01f, +1.858860952571e-01f, -9.300240079057e-02f,
+2.438947597681e-02f}; break;
+2.438947597681e-02f};
break;
case ApproximationPrecision::MULPE_1e_5:
case ApproximationPrecision::MULPE_Poly6: // (MSE=3.0079e-12, MAE=3.5307e-06, MaxUlpE=5.9838e+01)
c = {+9.999963876702e-01f, -3.330364633925e-01f, +1.959597060284e-01f, -1.220687452250e-01f,
+5.834036471395e-02f, -1.379661708254e-02f}; break;
+5.834036471395e-02f, -1.379661708254e-02f};
break;
case ApproximationPrecision::MULPE_1e_6:
case ApproximationPrecision::MULPE_Poly7: // (MSE=6.3489e-14, MAE=4.8826e-07, MaxUlpE=8.2764e+00)
c = {+9.999994992400e-01f, -3.332734078379e-01f, +1.988954540598e-01f, -1.351537940907e-01f,
+8.431852775558e-02f, -3.734345976535e-02f, +7.955832300869e-03f}; break;
+8.431852775558e-02f, -3.734345976535e-02f, +7.955832300869e-03f};
break;
case ApproximationPrecision::MULPE_Poly8: // (MSE=1.3696e-15, MAE=7.5850e-08, MaxUlpE=1.2850e+00)
c = {+9.999999220612e-01f, -3.333208398432e-01f, +1.997085632112e-01f, -1.402570625577e-01f,
+9.930940122930e-02f, -5.971380457112e-02f, +2.440561807586e-02f, -4.733710058459e-03f}; break;
+9.930940122930e-02f, -5.971380457112e-02f, +2.440561807586e-02f, -4.733710058459e-03f};
break;
}
// clang-format on

Expand Down
53 changes: 42 additions & 11 deletions src/IROperator.h
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,24 @@ Expr fast_sin(const Expr &x);
Expr fast_cos(const Expr &x);
// @}

/**
* Enum that declares several options for functions that are approximated
* by polynomial expansions. These polynomials can be optimized for three
* different metrics: Mean Squared Error, Maximum Absolute Error, or
* Maximum Units in Last Place (ULP) Error.
*
* Orthogonally to the optimization objective, these polynomials can vary
* in degree. Higher degree polynomials will give more precise results.
 * Note that the `X` in the `PolyX` enum values refers to the number of terms
* in the polynomial, and not the degree of the polynomial. E.g., even
* symmetric functions may be implemented using only even powers, for which
* `Poly3` would actually mean that terms in [1, x^2, x^4] are used.
*
* Additionally, if you don't care about number of terms in the polynomial
* and you do care about the maximal absolute error the approximation may have
* over the domain, you may use the `MAE_1e_x` values and the implementation
* will decide the appropriate polynomial degree that achieves this precision.
*/
enum class ApproximationPrecision {
/** Mean Squared Error Optimized. */
// @{
Expand All @@ -984,15 +1002,6 @@ enum class ApproximationPrecision {
MSE_Poly8,
// @}

/* Maximum Absolute Error Optimized. */
// @{
MAE_1e_2,
MAE_1e_3,
MAE_1e_4,
MAE_1e_5,
MAE_1e_6,
// @}

/** Number of terms in polynomial -- Optimized for Max Absolute Error. */
// @{
MAE_Poly2,
Expand All @@ -1015,19 +1024,41 @@ enum class ApproximationPrecision {
MULPE_Poly7,
MULPE_Poly8,
// @}

/* Maximum Absolute Error Optimized with given Maximal Absolute Error. */
// @{
MAE_1e_2,
MAE_1e_3,
MAE_1e_4,
MAE_1e_5,
MAE_1e_6,
// @}

/* Maximum ULP Error Optimized with given Maximal Absolute Error. */
// @{
MULPE_1e_2,
MULPE_1e_3,
MULPE_1e_4,
MULPE_1e_5,
MULPE_1e_6,
// @}
};

/** Fast vectorizable approximations for arctan and arctan2 for Float(32).
*
* Desired precision can be specified as either a maximum absolute error (MAE) or
* the number of terms in the polynomial approximation (see the ApproximationPrecision enum) which
* are optimized for either:
* - MSE (Mean Squared Error)
* - MAE (Maximum Absolute Error)
* - MULPE (Maximum Units in Last Place Error).
* The default (Max ULP Error Polynomial 6) has a MAE of 3.53e-6. For more info on the precision,
* see the table in IROperator.cpp.
*
* The default (Max ULP Error Polynomial 6) has a MAE of 3.53e-6.
* For more info on the precision, see the table in IROperator.cpp.
*
* Note: the polynomial uses odd powers, so the number of terms is not the degree of the polynomial.
* Note: Poly8 is only useful to increase precision for atan, and not for atan2.
 * Note: This function does not appear to be reliably faster on WebGPU (as of August 2024).
*/
// @{
Expr fast_atan(const Expr &x, ApproximationPrecision precision = ApproximationPrecision::MULPE_Poly6);
Expand Down
47 changes: 39 additions & 8 deletions test/correctness/fast_arctan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,46 @@

using namespace Halide;

// Measures how many trailing mantissa bits differ between two floats that
// share the same sign and exponent. Returns -100 as a sentinel when the
// sign/exponent fields disagree (mantissa comparison is meaningless across
// binades); callers that fold this through std::max against 0 will simply
// ignore such pairs.
int bits_diff(float fa, float fb) {
    const uint32_t a = Halide::Internal::reinterpret_bits<uint32_t>(fa);
    const uint32_t b = Halide::Internal::reinterpret_bits<uint32_t>(fb);
    // Top 9 bits = sign + biased exponent; they must match exactly.
    if ((a >> 23) != (b >> 23)) {
        return -100;
    }
    uint32_t delta = (a > b) ? (a - b) : (b - a);
    // Count the position of the highest set bit in the mantissa difference.
    int bits = 0;
    for (; delta != 0; delta >>= 1) {
        ++bits;
    }
    return bits;
}

int main(int argc, char **argv) {
Target target = get_jit_target_from_environment();

struct Prec {
ApproximationPrecision precision;
float epsilon;
const char *objective;
} precisions_to_test[] = {
{ApproximationPrecision::MAE_1e_2, 1e-2f},
{ApproximationPrecision::MAE_1e_3, 1e-3f},
{ApproximationPrecision::MAE_1e_4, 1e-4f},
{ApproximationPrecision::MAE_1e_5, 1e-5f},
{ApproximationPrecision::MAE_1e_6, 1e-6f}};
// MAE
{ApproximationPrecision::MAE_1e_2, 1e-2f, "MAE"},
{ApproximationPrecision::MAE_1e_3, 1e-3f, "MAE"},
{ApproximationPrecision::MAE_1e_4, 1e-4f, "MAE"},
{ApproximationPrecision::MAE_1e_5, 1e-5f, "MAE"},
{ApproximationPrecision::MAE_1e_6, 1e-6f, "MAE"},

// MULPE
{ApproximationPrecision::MULPE_1e_2, 1e-2f, "MULPE"},
{ApproximationPrecision::MULPE_1e_3, 1e-3f, "MULPE"},
{ApproximationPrecision::MULPE_1e_4, 1e-4f, "MULPE"},
{ApproximationPrecision::MULPE_1e_5, 1e-5f, "MULPE"},
{ApproximationPrecision::MULPE_1e_6, 1e-6f, "MULPE"},
};

for (Prec precision : precisions_to_test) {
printf("\nTesting for precision %e...\n", precision.epsilon);
printf("\nTesting for precision %.1e (%s optimized)...\n", precision.epsilon, precision.objective);
Func atan_f, atan2_f;
Var x, y;
const int steps = 1000;
Expand All @@ -36,18 +61,21 @@ int main(int argc, char **argv) {
printf(" Testing fast_atan() correctness... ");
Buffer<float> atan_result = atan_f.realize({steps});
float max_error = 0.0f;
int max_mantissa_error = 0;
for (int i = 0; i < steps; ++i) {
const float x = (i - steps / 2) / float(steps / 8);
const float atan_x = atan_result(i);
const float atan_x_ref = atan(x);
float abs_error = std::abs(atan_x_ref - atan_x);
int mantissa_error = bits_diff(atan_x, atan_x_ref);
max_error = std::max(max_error, abs_error);
max_mantissa_error = std::max(max_mantissa_error, mantissa_error);
if (abs_error > precision.epsilon) {
fprintf(stderr, "fast_atan(%.6f) = %.20f not equal to %.20f (error=%.5e)\n", x, atan_x, atan_x_ref, atan_x_ref - atan_x);
exit(1);
}
}
printf("Passed: max abs error: %.5e\n", max_error);
printf("Passed: max abs error: %.5e max mantissa bits wrong: %d\n", max_error, max_mantissa_error);

atan2_f(x, y) = fast_atan2(vx, vy, precision.precision);
if (target.has_gpu_feature()) {
Expand All @@ -61,21 +89,24 @@ int main(int argc, char **argv) {
printf(" Testing fast_atan2() correctness... ");
Buffer<float> atan2_result = atan2_f.realize({steps, steps});
max_error = 0.0f;
max_mantissa_error = 0;
for (int i = 0; i < steps; ++i) {
const float x = (i - steps / 2) / float(steps / 8);
for (int j = 0; j < steps; ++j) {
const float y = (j - steps / 2) / float(steps / 8);
const float atan2_x_y = atan2_result(i, j);
const float atan2_x_y_ref = atan2(x, y);
float abs_error = std::abs(atan2_x_y_ref - atan2_x_y);
int mantissa_error = bits_diff(atan2_x_y, atan2_x_y_ref);
max_error = std::max(max_error, abs_error);
max_mantissa_error = std::max(max_mantissa_error, mantissa_error);
if (abs_error > precision.epsilon) {
fprintf(stderr, "fast_atan2(%.6f, %.6f) = %.20f not equal to %.20f (error=%.5e)\n", x, y, atan2_x_y, atan2_x_y_ref, atan2_x_y_ref - atan2_x_y);
exit(1);
}
}
}
printf("Passed: max abs error: %.5e\n", max_error);
printf("Passed: max abs error: %.5e max mantissa bits wrong: %d\n", max_error, max_mantissa_error);
}

printf("Success!\n");
Expand Down
4 changes: 4 additions & 0 deletions test/performance/fast_arctan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ int main(int argc, char **argv) {
printf("[SKIP] Performance tests are meaningless and/or misleading under WebAssembly interpreter.\n");
return 0;
}
if (target.has_feature(Target::WebGPU)) {
printf("[SKIP] WebGPU seems to perform bad, and fast_atan is not really faster in all scenarios.\n");
return 0;
}

Var x, y;
const int test_w = 256;
Expand Down
Loading