Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fast atan and atan2 functions. #8388

Closed
wants to merge 28 commits into from
Closed

Conversation

mcourteaux
Copy link
Contributor

@mcourteaux mcourteaux commented Aug 10, 2024

Addresses #8243. Uses a polynomial approximation with odd powers: this way, it's immediately symmetrical around 0. Coefficients are optimized using my script which does iterative weight-adjusted least-squared-error (also included in PR; see below).

Added API

/** Struct that allows the user to specify several requirements for functions
 * that are approximated by polynomial expansions. These polynomials can be
 * optimized for four different metrics: Mean Squared Error, Maximum Absolute Error,
 * Maximum Units in Last Place (ULP) Error, or a 50%/50% blend of MAE and MULPE.
 *
 * Orthogonally to the optimization objective, these polynomials can vary
 * in degree. Higher degree polynomials will give more precise results.
 * Note that instead of specifying the degree, the number of terms is used instead.
 * E.g., even (i.e., symmetric) functions may be implemented using only even powers,
 * for which a number of terms of 4 would actually mean that terms
 * in [1, x^2, x^4, x^6] are used, which is degree 6.
 *
 * Additionally, if you don't care about number of terms in the polynomial
 * and you do care about the maximal absolute error the approximation may have
 * over the domain, you may specify values and the implementation
 * will decide the appropriate polynomial degree that achieves this precision.
 */
struct ApproximationPrecision {
    enum OptimizationObjective {
        MSE,        //< Mean Squared Error Optimized.
        MAE,        //< Optimized for Max Absolute Error.
        MULPE,      //< Optimized for Max ULP Error. ULP is "Units in Last Place", measured in IEEE 32-bit floats.
        MULPE_MAE,  //< Optimized for simultaneously Max ULP Error, and Max Absolute Error, each with a weight of 50%.
    } optimized_for;
    int constraint_min_poly_terms{0};           //< Number of terms in polynomial (zero for no constraint).
    float constraint_max_absolute_error{0.0f};  //< Max absolute error (zero for no constraint).
};

/** Fast vectorizable approximations for arctan and arctan2 for Float(32).
 *
 * Desired precision can be specified as either a maximum absolute error (MAE) or
 * the number of terms in the polynomial approximation (see the ApproximationPrecision enum) which
 * are optimized for either:
 *  - MSE (Mean Squared Error)
 *  - MAE (Maximum Absolute Error)
 *  - MULPE (Maximum Units in Last Place Error).
 *
 * The default (Max ULP Error Polynomial of 6 terms) has a MAE of 3.53e-6.
 * For more info on the available approximations and their precisions, see the table in ApproximationTables.cpp.
 *
 * Note: the polynomial uses odd powers, so the number of terms is not the degree of the polynomial.
 * Note: the polynomial with 8 terms is only useful to increase precision for fast_atan, and not for fast_atan2.
 * Note: the performance of this functions seem to be not reliably faster on WebGPU (for now, August 2024).
 */
// @{
Expr fast_atan(const Expr &x, ApproximationPrecision precision = {ApproximationPrecision::MULPE, 6});
Expr fast_atan2(const Expr &y, const Expr &x, ApproximationPrecision = {ApproximationPrecision::MULPE, 6});
// @}

I designed this new ApproximationPrecision such that it can be used for other vectorizable functions at a later point as well, such as for fast_sin and fast_cos if we want that at some point. Note that I chose for MAE_1e_5 style of notation, instead of 5Decimals because 5 decimals suggests that there will be 5 decimals correct, which is technically less correct than saying that the maximal absolute error will be below 1e-5.

Performance difference:

Linux/CPU (with precision MAE_1e_5):

                  atan: 7.427325 ns per atan
 fast_atan (MAE 1e-02): 0.604592 ns per atan (91.9% faster)  [per invokation: 2.535843 ms]
 fast_atan (MAE 1e-03): 0.695281 ns per atan (90.6% faster)  [per invokation: 2.916222 ms]
 fast_atan (MAE 1e-04): 0.787722 ns per atan (89.4% faster)  [per invokation: 3.303945 ms]
 fast_atan (MAE 1e-05): 0.863543 ns per atan (88.4% faster)  [per invokation: 3.621961 ms]
 fast_atan (MAE 1e-06): 0.951112 ns per atan (87.2% faster)  [per invokation: 3.989254 ms]

                  atan2: 13.759876 ns per atan2
 fast_atan2 (MAE 1e-02): 1.052900 ns per atan2 (92.3% faster)  [per invokation: 4.416183 ms]
 fast_atan2 (MAE 1e-03): 1.124720 ns per atan2 (91.8% faster)  [per invokation: 4.717417 ms]
 fast_atan2 (MAE 1e-04): 1.245389 ns per atan2 (90.9% faster)  [per invokation: 5.223540 ms]
 fast_atan2 (MAE 1e-05): 1.304229 ns per atan2 (90.5% faster)  [per invokation: 5.470334 ms]
 fast_atan2 (MAE 1e-06): 1.407788 ns per atan2 (89.8% faster)  [per invokation: 5.904690 ms]
Success!

On Linux/CUDA, it's slightly faster than the default LLVM implementation (there is no atan instruction in PTX):

                  atan: 0.012694 ns per atan
 fast_atan (MAE 1e-02): 0.008084 ns per atan (36.3% faster)  [per invokation: 0.542537 ms]
 fast_atan (MAE 1e-03): 0.008257 ns per atan (35.0% faster)  [per invokation: 0.554145 ms]
 fast_atan (MAE 1e-04): 0.008580 ns per atan (32.4% faster)  [per invokation: 0.575821 ms]
 fast_atan (MAE 1e-05): 0.009693 ns per atan (23.6% faster)  [per invokation: 0.650511 ms]
 fast_atan (MAE 1e-06): 0.009996 ns per atan (21.3% faster)  [per invokation: 0.670806 ms]

                  atan2: 0.016339 ns per atan2
 fast_atan2 (MAE 1e-02): 0.010460 ns per atan2 (36.0% faster)  [per invokation: 0.701942 ms]
 fast_atan2 (MAE 1e-03): 0.010887 ns per atan2 (33.4% faster)  [per invokation: 0.730619 ms]
 fast_atan2 (MAE 1e-04): 0.011134 ns per atan2 (31.9% faster)  [per invokation: 0.747207 ms]
 fast_atan2 (MAE 1e-05): 0.011699 ns per atan2 (28.4% faster)  [per invokation: 0.785120 ms]
 fast_atan2 (MAE 1e-06): 0.012122 ns per atan2 (25.8% faster)  [per invokation: 0.813505 ms]
Success!

On Linux/OpenCL, it is also slightly faster:

                  atan: 0.012427 ns per atan
 fast_atan (MAE 1e-02): 0.008740 ns per atan (29.7% faster)  [per invokation: 0.586513 ms]
 fast_atan (MAE 1e-03): 0.008920 ns per atan (28.2% faster)  [per invokation: 0.598603 ms]
 fast_atan (MAE 1e-04): 0.009326 ns per atan (25.0% faster)  [per invokation: 0.625840 ms]
 fast_atan (MAE 1e-05): 0.010362 ns per atan (16.6% faster)  [per invokation: 0.695404 ms]
 fast_atan (MAE 1e-06): 0.011196 ns per atan ( 9.9% faster)  [per invokation: 0.751366 ms]

                  atan2: 0.016028 ns per atan2
 fast_atan2 (MAE 1e-02): 0.011978 ns per atan2 (25.3% faster)  [per invokation: 0.803816 ms]
 fast_atan2 (MAE 1e-03): 0.011715 ns per atan2 (26.9% faster)  [per invokation: 0.786199 ms]
 fast_atan2 (MAE 1e-04): 0.011774 ns per atan2 (26.5% faster)  [per invokation: 0.790166 ms]
 fast_atan2 (MAE 1e-05): 0.012266 ns per atan2 (23.5% faster)  [per invokation: 0.823142 ms]
 fast_atan2 (MAE 1e-06): 0.012728 ns per atan2 (20.6% faster)  [per invokation: 0.854140 ms]
Success!

Precision tests:

Testing for precision 1.000000e-02...
    Testing fast_atan() correctness...  Passed: max abs error: 4.94057e-03
    Testing fast_atan2() correctness...  Passed: max abs error: 4.99773e-03

Testing for precision 1.000000e-03...
    Testing fast_atan() correctness...  Passed: max abs error: 6.07625e-04
    Testing fast_atan2() correctness...  Passed: max abs error: 6.13213e-04

Testing for precision 1.000000e-04...
    Testing fast_atan() correctness...  Passed: max abs error: 8.12709e-05
    Testing fast_atan2() correctness...  Passed: max abs error: 8.20160e-05

Testing for precision 1.000000e-05...
    Testing fast_atan() correctness...  Passed: max abs error: 1.69873e-06
    Testing fast_atan2() correctness...  Passed: max abs error: 1.90735e-06

Testing for precision 1.000000e-06...
    Testing fast_atan() correctness...  Passed: max abs error: 2.98023e-07
    Testing fast_atan2() correctness...  Passed: max abs error: 4.76837e-07
Success!

Optimizer

This PR includes a Python optimization script to find the coefficients of the polynomials:

atan_poly5_optimization

While I didn't do anything very scientific or looked at research papers, I get a hunch that the results from this script are really good (and may actually converge to optimal).

If my optimization makes sense, then I have some funny observation: I get different coefficients for all of the fast approximations we have. See below.

Better coefficients for exp()?

My result:

// Coefficients with max error: 1.0835e-07
const float c_0(9.999998916957e-01f);
const float c_1(1.000010959810e+00f);
const float c_2(4.998191326645e-01f);
const float c_3(1.677545067148e-01f);
const float c_4(3.874100973369e-02f);
const float c_5(1.185256835401e-02f);

versus current Halide code:

Halide/src/IROperator.cpp

Lines 1432 to 1439 in 3cdeb53

float coeff[] = {
0.01314350012789660196f,
0.03668965196652099192f,
0.16873890085469545053f,
0.49970514590562437052f,
1.0f,
1.0f};
Expr result = evaluate_polynomial(x, coeff, sizeof(coeff) / sizeof(coeff[0]));

Better coefficients for sin()?

// Coefficients with max error: 1.3500e-11
const float c_1(9.999999998902e-01f);
const float c_3(-1.666666654172e-01f);
const float c_5(8.333329271330e-03f);
const float c_7(-1.984070354590e-04f);
const float c_9(2.751888510663e-06f);
const float c_11(-2.379517255457e-08f);

Notice that my optimization gives maximal error of 1.35e-11, instead of the promised 1e-5, with degree 6.

Versus:

Halide/src/IROperator.cpp

Lines 1390 to 1394 in 3cdeb53

const float sin_c2 = -0.16666667163372039794921875f;
const float sin_c4 = 8.333347737789154052734375e-3;
const float sin_c6 = -1.9842604524455964565277099609375e-4;
const float sin_c8 = 2.760012648650445044040679931640625e-6;
const float sin_c10 = -2.50293279435709337121807038784027099609375e-8;

If this is true (I don't see a reason why it wouldn't), that would mean we can remove a few terms to get faster version that still provides the promised precision.

Better coefficients for cos()?

// Coefficients with max error: 2.2274e-10
const float c_0(9.999999997814e-01f);
const float c_2(-4.999999936010e-01f);
const float c_4(4.166663631608e-02f);
const float c_6(-1.388836211466e-03f);
const float c_8(2.476019687789e-05f);
const float c_10(-2.605210837614e-07f);

versus:

Halide/src/IROperator.cpp

Lines 1396 to 1400 in 3cdeb53

const float cos_c2 = -0.5f;
const float cos_c4 = 4.166664183139801025390625e-2;
const float cos_c6 = -1.388833043165504932403564453125e-3;
const float cos_c8 = 2.47562347794882953166961669921875e-5;
const float cos_c10 = -2.59630184018533327616751194000244140625e-7;

Better coefficients for log()?

// Coefficients with max error: 2.2155e-08
const float c_0(2.215451521194e-08f);
const float c_1(9.999956758035e-01f);
const float c_2(-4.998600090003e-01f);
const float c_3(3.315834102478e-01f);
const float c_4(-2.389843462478e-01f);
const float c_5(1.605007787295e-01f);
const float c_6(-8.022296753549e-02f);
const float c_7(2.030898293785e-02f);

versus:

Halide/src/IROperator.cpp

Lines 1357 to 1365 in 3cdeb53

float coeff[] = {
0.07640318789187280912f,
-0.16252961013874300811f,
0.20625219040645212387f,
-0.25110261010892864775f,
0.33320464908377461777f,
-0.49997513376789826101f,
1.0f,
0.0f};

@mcourteaux
Copy link
Contributor Author

mcourteaux commented Aug 11, 2024

Apparently Windows/OpenCL on the build bot does not have a performance improvement, but even a performance degradation (about 15%):

C:\build_bot\worker\halide-testbranch-main-llvm20-x86-64-windows-cmake\halide-build\bin\performance_fast_arctan.exe
atan: 6.347030 ns per pixel
fast_atan: 7.295760 ns per pixel
atan2: 0.923191 ns per pixel
fast_atan2: 0.926148 ns per pixel
fast_atan more than 10% slower than atan on GPU.

Suggestions?

@mcourteaux
Copy link
Contributor Author

GPU performance test was severely memory bandwidth limited. This has been worked around by computing many (1024) arctans per output and summing them. Now --at least on my system-- they are faster. See updated performance reports.

@mcourteaux
Copy link
Contributor Author

Okay, this is ready for review. Vulkan is slow, but that is apparently known well...

@mcourteaux
Copy link
Contributor Author

Oh dear... I don't even know what WebGPU is... @steven-johnson Is this supposed to be an actual platform that is fast, and where performance metrics make sense? I can treat it like Vulkan, where it's just "meh, at least some are faster..."?

@steven-johnson
Copy link
Contributor

Oh dear... I don't even know what WebGPU is... @steven-johnson Is this supposed to be an actual platform that is fast, and where performance metrics make sense? I can treat it like Vulkan, where it's just "meh, at least some are faster..."?

https://en.wikipedia.org/wiki/WebGPU
https://www.w3.org/TR/webgpu/
https://github.com/gpuweb/gpuweb/wiki/Implementation-Status

@derek-gerstmann
Copy link
Contributor

Okay, this is ready for review. Vulkan is slow, but that is apparently known well...

I don't think Vulkan is necessarily slow ... I think the benchmark loop is including initialization overhead. See my follow up here: #7202

@abadams
Copy link
Member

abadams commented Aug 13, 2024

Very cool! I have some concerns with the error metric though. Decimal digits of error isn't a great metric. E.g. having a value of 0.0001 when it's supposed to be zero is much much worse than having a value of 0.3701 when it's supposed to be 0.37. Relative error isn't great either, due to the singularity at zero. A better metric is ULPs, which is the maximum number of distinct floating point values in between the answer and the correct answer.

There are also cases where you want a hard constraint as opposed to a minimization. exp(0) should be exactly one, and I guess I decided its derivative should be exactly one too, which explains the different in coefficients.

@mcourteaux
Copy link
Contributor Author

A better metric is ULPs, which is the maximum number of distinct floating point values in between the answer and the correct answer.

@abadams I improved the optimization script a lot. I added support for ULP optimization: it optimizes very nicely for maximal bit error.

atan_6_mulpe

When instead optimizing for MAE, we see the max ULP distance increase:

atan_6_mae

I changed the default to the ULP-optimized one, but to keep the maximal absolute error under 1e-5, I had to choose the higher-degree polynomial. Overall still good.

@derek-gerstmann Thanks a lot for investigating the performance issue! I now also get very fast Vulkan performance. I wonder why the overhead is so huge in Vulkan, and not there in other backends?

Vulkan:

              atan: 0.009071 ns per atan
 fast_atan (Poly2): 0.005076 ns per atan (44.0% faster)  [per invokation: 0.340618 ms]
 fast_atan (Poly3): 0.005279 ns per atan (41.8% faster)  [per invokation: 0.354284 ms]
 fast_atan (Poly4): 0.005484 ns per atan (39.5% faster)  [per invokation: 0.368018 ms]
 fast_atan (Poly5): 0.005925 ns per atan (34.7% faster)  [per invokation: 0.397631 ms]
 fast_atan (Poly6): 0.006225 ns per atan (31.4% faster)  [per invokation: 0.417756 ms]
 fast_atan (Poly7): 0.006448 ns per atan (28.9% faster)  [per invokation: 0.432734 ms]
 fast_atan (Poly8): 0.006765 ns per atan (25.4% faster)  [per invokation: 0.453989 ms]

              atan2: 0.013717 ns per atan2
 fast_atan2 (Poly2): 0.007812 ns per atan2 (43.0% faster)  [per invokation: 0.524279 ms]
 fast_atan2 (Poly3): 0.007604 ns per atan2 (44.6% faster)  [per invokation: 0.510290 ms]
 fast_atan2 (Poly4): 0.008016 ns per atan2 (41.6% faster)  [per invokation: 0.537952 ms]
 fast_atan2 (Poly5): 0.008544 ns per atan2 (37.7% faster)  [per invokation: 0.573364 ms]
 fast_atan2 (Poly6): 0.008204 ns per atan2 (40.2% faster)  [per invokation: 0.550533 ms]
 fast_atan2 (Poly7): 0.008757 ns per atan2 (36.2% faster)  [per invokation: 0.587663 ms]
 fast_atan2 (Poly8): 0.008629 ns per atan2 (37.1% faster)  [per invokation: 0.579092 ms]
Success!

CUDA:

              atan: 0.010663 ns per atan
 fast_atan (Poly2): 0.006854 ns per atan (35.7% faster)  [per invokation: 0.459946 ms]
 fast_atan (Poly3): 0.006838 ns per atan (35.9% faster)  [per invokation: 0.458894 ms]
 fast_atan (Poly4): 0.007196 ns per atan (32.5% faster)  [per invokation: 0.482914 ms]
 fast_atan (Poly5): 0.007646 ns per atan (28.3% faster)  [per invokation: 0.513141 ms]
 fast_atan (Poly6): 0.008205 ns per atan (23.1% faster)  [per invokation: 0.550595 ms]
 fast_atan (Poly7): 0.008496 ns per atan (20.3% faster)  [per invokation: 0.570149 ms]
 fast_atan (Poly8): 0.009008 ns per atan (15.5% faster)  [per invokation: 0.604508 ms]

              atan2: 0.014594 ns per atan2
 fast_atan2 (Poly2): 0.009409 ns per atan2 (35.5% faster)  [per invokation: 0.631451 ms]
 fast_atan2 (Poly3): 0.009957 ns per atan2 (31.8% faster)  [per invokation: 0.668201 ms]
 fast_atan2 (Poly4): 0.010289 ns per atan2 (29.5% faster)  [per invokation: 0.690511 ms]
 fast_atan2 (Poly5): 0.010255 ns per atan2 (29.7% faster)  [per invokation: 0.688207 ms]
 fast_atan2 (Poly6): 0.010748 ns per atan2 (26.4% faster)  [per invokation: 0.721268 ms]
 fast_atan2 (Poly7): 0.011497 ns per atan2 (21.2% faster)  [per invokation: 0.771529 ms]
 fast_atan2 (Poly8): 0.011326 ns per atan2 (22.4% faster)  [per invokation: 0.760067 ms]
Success!

Vulkan is now even faster than CUDA! 🤯

@mcourteaux
Copy link
Contributor Author

mcourteaux commented Aug 13, 2024

@steven-johnson The build just broke on something LLVM related it seems... There seems to be no related commit to Halide. Does LLVM constantly update with every build?

Edit: I found the commit: llvm/llvm-project@75c7bca

Fix separately PR'd in #8391

@steven-johnson
Copy link
Contributor

@steven-johnson The build just broke on something LLVM related it seems... There seems to be no related commit to Halide. Does LLVM constantly update with every build?

We rebuild LLVM once a day, about 2AM Pacific time.

@mcourteaux
Copy link
Contributor Author

@abadams I added the check that counts number of wrong mantissa bits:

Testing for precision 1.0e-02 (MAE optimized)...
    Testing fast_atan() correctness...  Passed: max abs error: 4.96906e-03  max mantissa bits wrong: 19
    Testing fast_atan2() correctness...  Passed: max abs error: 4.96912e-03  max mantissa bits wrong: 19

Testing for precision 1.0e-03 (MAE optimized)...
    Testing fast_atan() correctness...  Passed: max abs error: 6.10709e-04  max mantissa bits wrong: 17
    Testing fast_atan2() correctness...  Passed: max abs error: 6.10709e-04  max mantissa bits wrong: 17

Testing for precision 1.0e-04 (MAE optimized)...
    Testing fast_atan() correctness...  Passed: max abs error: 8.16584e-05  max mantissa bits wrong: 14
    Testing fast_atan2() correctness...  Passed: max abs error: 8.17776e-05  max mantissa bits wrong: 14

Testing for precision 1.0e-05 (MAE optimized)...
    Testing fast_atan() correctness...  Passed: max abs error: 1.78814e-06  max mantissa bits wrong: 9
    Testing fast_atan2() correctness...  Passed: max abs error: 1.90735e-06  max mantissa bits wrong: 9

Testing for precision 1.0e-06 (MAE optimized)...
    Testing fast_atan() correctness...  Passed: max abs error: 3.57628e-07  max mantissa bits wrong: 6
    Testing fast_atan2() correctness...  Passed: max abs error: 4.76837e-07  max mantissa bits wrong: 7

Testing for precision 1.0e-02 (MULPE optimized)...
    Testing fast_atan() correctness...  Passed: max abs error: 1.31637e-03  max mantissa bits wrong: 15
    Testing fast_atan2() correctness...  Passed: max abs error: 1.31637e-03  max mantissa bits wrong: 15

Testing for precision 1.0e-03 (MULPE optimized)...
    Testing fast_atan() correctness...  Passed: max abs error: 1.54853e-04  max mantissa bits wrong: 12
    Testing fast_atan2() correctness...  Passed: max abs error: 1.54972e-04  max mantissa bits wrong: 12

Testing for precision 1.0e-04 (MULPE optimized)...
    Testing fast_atan() correctness...  Passed: max abs error: 2.53320e-05  max mantissa bits wrong: 9
    Testing fast_atan2() correctness...  Passed: max abs error: 2.55108e-05  max mantissa bits wrong: 9

Testing for precision 1.0e-05 (MULPE optimized)...
    Testing fast_atan() correctness...  Passed: max abs error: 3.63588e-06  max mantissa bits wrong: 6
    Testing fast_atan2() correctness...  Passed: max abs error: 3.81470e-06  max mantissa bits wrong: 6

Testing for precision 1.0e-06 (MULPE optimized)...
    Testing fast_atan() correctness...  Passed: max abs error: 5.96046e-07  max mantissa bits wrong: 4
    Testing fast_atan2() correctness...  Passed: max abs error: 7.15256e-07  max mantissa bits wrong: 4
Success!

Pay attention to the MULPE optimized ones: they are significantly lower than the MAE optimized ones.

Copy link
Contributor

@steven-johnson steven-johnson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, but would like @abadams or @zvookin to weigh in as well

@steven-johnson
Copy link
Contributor

Ping to @abadams or @zvookin for review

@mcourteaux
Copy link
Contributor Author

Cut polynomial + merge it + later take care of other transcendentals.

@mcourteaux
Copy link
Contributor Author

@abadams I updated the PR, and believe this is a nice compromise of options. It is in line with your initial thoughts on just specifying the precision yourself. I have made a table of approximations and their precisions. Then a new auxiliary function selects an approximation from that table that satisfies your requirements. This clears out the header (no more one million enum options), and clears out the source file, by not having the table sitting inside of the fast_atan function.

@steven-johnson
Copy link
Contributor

Looks like this is ready for final review... ?

Copy link
Contributor

@steven-johnson steven-johnson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Style nits

}

double score = obj_score + term_count_score + precision_score - penalty;
// std::printf("Score for %zu (%zu terms): %f = %d + %d + %f - penalty %f\n", i, e.coefficients.size(), score, obj_score, term_count_score, precision_score, penalty);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove commented-out code

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrapped in an #if

@@ -0,0 +1,21 @@
#pragma once
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Halide doesn't use #pragma once; instead, wrap in

#ifndef HALIDE_APPROXIMATION_TABLES_H_
#define HALIDE_APPROXIMATION_TABLES_H_
...

#endif

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed that, but without the trailing _, as that seemed to be the style, looking at other files.

@@ -219,8 +219,7 @@ target_sources(
WrapCalls.h
)

# The sources that go into libHalide. For the sake of IDE support, headers that
# exist in src/ but are not public should be included here.
# The sources that go into libHalide.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you alter the comment?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because there are no headers in that list. That comments is clearly outdated. Unless I'm wildly misunderstanding something.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mcourteaux -- you understand it fine, the comment is out of date. The list of headers is right above.

@steven-johnson
Copy link
Contributor

Is this ready to land (pending review comments)?

@alexreinking
Copy link
Member

One nit: can we move the Python script from src/ to tools/. That's where our other code generators go.

@mcourteaux
Copy link
Contributor Author

Performance of new functions on x64 with AVX2:

      sin           :   6.27932 ns per evaluation  [per invokation: 52.675 ms]
 fast_sin (   Poly2):   0.89839 ns per evaluation  [per invokation:  7.536 ms] [force_approx   85.7% faster]
 fast_sin (   Poly3):   0.90588 ns per evaluation  [per invokation:  7.599 ms] [force_approx   85.6% faster]
 fast_sin (   Poly4):   0.94848 ns per evaluation  [per invokation:  7.956 ms] [force_approx   84.9% faster]
 fast_sin (   Poly5):   0.96361 ns per evaluation  [per invokation:  8.083 ms] [force_approx   84.7% faster]
 fast_sin (   Poly6):   1.06084 ns per evaluation  [per invokation:  8.899 ms] [force_approx   83.1% faster]
 fast_sin (   Poly7):   1.13292 ns per evaluation  [per invokation:  9.504 ms] [force_approx   82.0% faster]
 fast_sin (   Poly8):   1.17670 ns per evaluation  [per invokation:  9.871 ms] [force_approx   81.3% faster]
 fast_sin (MAE 1e-2):   0.78337 ns per evaluation  [per invokation:  6.571 ms] [force_approx   87.5% faster]
 fast_sin (MAE 1e-3):   0.83185 ns per evaluation  [per invokation:  6.978 ms] [force_approx   86.8% faster]
 fast_sin (MAE 1e-4):   0.83603 ns per evaluation  [per invokation:  7.013 ms] [force_approx   86.7% faster]
 fast_sin (MAE 1e-5):   0.88843 ns per evaluation  [per invokation:  7.453 ms] [force_approx   85.9% faster]
 fast_sin (MAE 1e-6):   0.89462 ns per evaluation  [per invokation:  7.505 ms] [force_approx   85.8% faster]
 fast_sin (MAE 1e-7):   0.97449 ns per evaluation  [per invokation:  8.175 ms] [force_approx   84.5% faster]
 fast_sin (MAE 1e-8):   0.96633 ns per evaluation  [per invokation:  8.106 ms] [force_approx   84.6% faster]

      cos           :   4.14441 ns per evaluation  [per invokation: 34.766 ms]
 fast_cos (   Poly2):   0.87214 ns per evaluation  [per invokation:  7.316 ms] [force_approx   79.0% faster]
 fast_cos (   Poly3):   0.91229 ns per evaluation  [per invokation:  7.653 ms] [force_approx   78.0% faster]
 fast_cos (   Poly4):   0.97219 ns per evaluation  [per invokation:  8.155 ms] [force_approx   76.5% faster]
 fast_cos (   Poly5):   1.03300 ns per evaluation  [per invokation:  8.665 ms] [force_approx   75.1% faster]
 fast_cos (   Poly6):   1.09260 ns per evaluation  [per invokation:  9.165 ms] [force_approx   73.6% faster]
 fast_cos (   Poly7):   1.16861 ns per evaluation  [per invokation:  9.803 ms] [force_approx   71.8% faster]
 fast_cos (   Poly8):   1.22586 ns per evaluation  [per invokation: 10.283 ms] [force_approx   70.4% faster]
 fast_cos (MAE 1e-2):   0.86472 ns per evaluation  [per invokation:  7.254 ms] [force_approx   79.1% faster]
 fast_cos (MAE 1e-3):   0.88472 ns per evaluation  [per invokation:  7.422 ms] [force_approx   78.7% faster]
 fast_cos (MAE 1e-4):   0.88596 ns per evaluation  [per invokation:  7.432 ms] [force_approx   78.6% faster]
 fast_cos (MAE 1e-5):   0.96200 ns per evaluation  [per invokation:  8.070 ms] [force_approx   76.8% faster]
 fast_cos (MAE 1e-6):   0.96340 ns per evaluation  [per invokation:  8.082 ms] [force_approx   76.8% faster]
 fast_cos (MAE 1e-7):   1.02604 ns per evaluation  [per invokation:  8.607 ms] [force_approx   75.2% faster]
 fast_cos (MAE 1e-8):   1.02788 ns per evaluation  [per invokation:  8.622 ms] [force_approx   75.2% faster]

      exp           :   1.00127 ns per evaluation  [per invokation:  8.399 ms]
 fast_exp (   Poly2):   0.63113 ns per evaluation  [per invokation:  5.294 ms] [force_approx   37.0% faster]
 fast_exp (   Poly3):   0.69646 ns per evaluation  [per invokation:  5.842 ms] [force_approx   30.4% faster]
 fast_exp (   Poly4):   0.78186 ns per evaluation  [per invokation:  6.559 ms] [force_approx   21.9% faster]
 fast_exp (   Poly5):   0.82727 ns per evaluation  [per invokation:  6.940 ms] [force_approx   17.4% faster]
 fast_exp (   Poly6):   0.88910 ns per evaluation  [per invokation:  7.458 ms] [force_approx   11.2% faster]
 fast_exp (   Poly7):   0.96408 ns per evaluation  [per invokation:  8.087 ms] [force_approx   equally fast ( +3.7% faster)]
 fast_exp (   Poly8):   0.96186 ns per evaluation  [per invokation:  8.069 ms] [force_approx   equally fast ( +3.9% faster)]
 fast_exp (MAE 1e-2):   0.59507 ns per evaluation  [per invokation:  4.992 ms] [force_approx   40.6% faster]
 fast_exp (MAE 1e-3):   0.63098 ns per evaluation  [per invokation:  5.293 ms] [force_approx   37.0% faster]
 fast_exp (MAE 1e-4):   0.69734 ns per evaluation  [per invokation:  5.850 ms] [force_approx   30.4% faster]
 fast_exp (MAE 1e-5):   0.69197 ns per evaluation  [per invokation:  5.805 ms] [force_approx   30.9% faster]
 fast_exp (MAE 1e-6):   0.76765 ns per evaluation  [per invokation:  6.440 ms] [force_approx   23.3% faster]
 fast_exp (MAE 1e-7):   0.83347 ns per evaluation  [per invokation:  6.992 ms] [force_approx   16.8% faster]
 fast_exp (MAE 1e-8):   0.83533 ns per evaluation  [per invokation:  7.007 ms] [force_approx   16.6% faster]

      log           :   1.12527 ns per evaluation  [per invokation:  9.439 ms]
 fast_log (   Poly2):   0.51076 ns per evaluation  [per invokation:  4.285 ms] [force_approx   54.6% faster]
 fast_log (   Poly3):   0.54429 ns per evaluation  [per invokation:  4.566 ms] [force_approx   51.6% faster]
 fast_log (   Poly4):   0.65385 ns per evaluation  [per invokation:  5.485 ms] [force_approx   41.9% faster]
 fast_log (   Poly5):   0.69907 ns per evaluation  [per invokation:  5.864 ms] [force_approx   37.9% faster]
 fast_log (   Poly6):   0.76607 ns per evaluation  [per invokation:  6.426 ms] [force_approx   31.9% faster]
 fast_log (   Poly7):   0.88370 ns per evaluation  [per invokation:  7.413 ms] [force_approx   21.5% faster]
 fast_log (   Poly8):   0.94692 ns per evaluation  [per invokation:  7.943 ms] [force_approx   15.8% faster]
 fast_log (MAE 1e-2):   0.52260 ns per evaluation  [per invokation:  4.384 ms] [force_approx   53.6% faster]
 fast_log (MAE 1e-3):   0.64560 ns per evaluation  [per invokation:  5.416 ms] [force_approx   42.6% faster]
 fast_log (MAE 1e-4):   0.69668 ns per evaluation  [per invokation:  5.844 ms] [force_approx   38.1% faster]
 fast_log (MAE 1e-5):   0.77382 ns per evaluation  [per invokation:  6.491 ms] [force_approx   31.2% faster]
 fast_log (MAE 1e-6):   0.88377 ns per evaluation  [per invokation:  7.414 ms] [force_approx   21.5% faster]
 fast_log (MAE 1e-7):   0.96555 ns per evaluation  [per invokation:  8.100 ms] [force_approx   14.2% faster]
 fast_log (MAE 1e-8):   0.98732 ns per evaluation  [per invokation:  8.282 ms] [force_approx   12.3% faster]

Passed 112 / 112 performance test.
Success!

And on CUDA:

      sin           :   0.00980 ns per evaluation  [per invokation:  1.316 ms]
 fast_sin (   Poly2):   0.00479 ns per evaluation  [per invokation:  0.643 ms] [force_approx   51.2% faster]
 fast_sin (   Poly3):   0.00559 ns per evaluation  [per invokation:  0.750 ms] [force_approx   43.0% faster]
 fast_sin (   Poly4):   0.00594 ns per evaluation  [per invokation:  0.797 ms] [force_approx   39.4% faster]
 fast_sin (   Poly5):   0.00622 ns per evaluation  [per invokation:  0.835 ms] [force_approx   36.5% faster]
 fast_sin (   Poly6):   0.00600 ns per evaluation  [per invokation:  0.806 ms] [force_approx   38.8% faster]
 fast_sin (   Poly7):   0.00885 ns per evaluation  [per invokation:  1.187 ms] [force_approx    9.7% faster]
 fast_sin (   Poly8):   0.00872 ns per evaluation  [per invokation:  1.170 ms] [force_approx   11.1% faster]
 fast_sin (MAE 1e-2):   0.00481 ns per evaluation  [per invokation:  0.645 ms] [force_approx   51.0% faster]
 fast_sin (MAE 1e-3):   0.00557 ns per evaluation  [per invokation:  0.748 ms] [force_approx   43.1% faster]
 fast_sin (MAE 1e-4):   0.00584 ns per evaluation  [per invokation:  0.784 ms] [force_approx   40.4% faster]
 fast_sin (MAE 1e-5):   0.00555 ns per evaluation  [per invokation:  0.745 ms] [force_approx   43.4% faster]
 fast_sin (MAE 1e-6):   0.00554 ns per evaluation  [per invokation:  0.743 ms] [force_approx   43.5% faster]
 fast_sin (MAE 1e-7):   0.00592 ns per evaluation  [per invokation:  0.794 ms] [force_approx   39.6% faster]
 fast_sin (MAE 1e-8):   0.00593 ns per evaluation  [per invokation:  0.796 ms] [force_approx   39.5% faster]

      cos           :   0.00998 ns per evaluation  [per invokation:  1.340 ms]
 fast_cos (   Poly2):   0.00514 ns per evaluation  [per invokation:  0.689 ms] [force_approx   48.6% faster]
 fast_cos (   Poly3):   0.00547 ns per evaluation  [per invokation:  0.734 ms] [force_approx   45.2% faster]
 fast_cos (   Poly4):   0.00620 ns per evaluation  [per invokation:  0.832 ms] [force_approx   37.9% faster]
 fast_cos (   Poly5):   0.00586 ns per evaluation  [per invokation:  0.787 ms] [force_approx   41.3% faster]
 fast_cos (   Poly6):   0.00619 ns per evaluation  [per invokation:  0.831 ms] [force_approx   38.0% faster]
 fast_cos (   Poly7):   0.00672 ns per evaluation  [per invokation:  0.902 ms] [force_approx   32.7% faster]
 fast_cos (   Poly8):   0.00809 ns per evaluation  [per invokation:  1.085 ms] [force_approx   19.0% faster]
 fast_cos (MAE 1e-2):   0.00513 ns per evaluation  [per invokation:  0.689 ms] [force_approx   48.6% faster]
 fast_cos (MAE 1e-3):   0.00548 ns per evaluation  [per invokation:  0.735 ms] [force_approx   45.1% faster]
 fast_cos (MAE 1e-4):   0.00547 ns per evaluation  [per invokation:  0.735 ms] [force_approx   45.2% faster]
 fast_cos (MAE 1e-5):   0.00590 ns per evaluation  [per invokation:  0.792 ms] [force_approx   40.9% faster]
 fast_cos (MAE 1e-6):   0.00593 ns per evaluation  [per invokation:  0.796 ms] [force_approx   40.6% faster]
 fast_cos (MAE 1e-7):   0.00599 ns per evaluation  [per invokation:  0.804 ms] [force_approx   40.0% faster]
 fast_cos (MAE 1e-8):   0.00602 ns per evaluation  [per invokation:  0.808 ms] [force_approx   39.7% faster]

      exp           :   0.00585 ns per evaluation  [per invokation:  0.785 ms]
 fast_exp (   Poly2):   0.00465 ns per evaluation  [per invokation:  0.625 ms] [force_approx   20.4% faster]
 fast_exp (   Poly3):   0.00477 ns per evaluation  [per invokation:  0.640 ms] [force_approx   18.6% faster]
 fast_exp (   Poly4):   0.00482 ns per evaluation  [per invokation:  0.647 ms] [force_approx   17.6% faster]
 fast_exp (   Poly5):   0.00514 ns per evaluation  [per invokation:  0.690 ms] [force_approx   12.1% faster]
 fast_exp (   Poly6):   0.00533 ns per evaluation  [per invokation:  0.715 ms] [force_approx   equally fast ( +8.9% faster)]
 fast_exp (   Poly7):   0.00534 ns per evaluation  [per invokation:  0.717 ms] [force_approx   equally fast ( +8.7% faster)]
 fast_exp (   Poly8):   0.00535 ns per evaluation  [per invokation:  0.718 ms] [force_approx   equally fast ( +8.5% faster)]
 fast_exp (MAE 1e-2):   0.00494 ns per evaluation  [per invokation:  0.663 ms] [force_approx   15.6% faster]
 fast_exp (MAE 1e-3):   0.00466 ns per evaluation  [per invokation:  0.625 ms] [force_approx   20.4% faster]
 fast_exp (MAE 1e-4):   0.00478 ns per evaluation  [per invokation:  0.641 ms] [force_approx   18.3% faster]
 fast_exp (MAE 1e-5):   0.00480 ns per evaluation  [per invokation:  0.644 ms] [force_approx   18.0% faster]
 fast_exp (MAE 1e-6):   0.00505 ns per evaluation  [per invokation:  0.678 ms] [force_approx   13.7% faster]
 fast_exp (MAE 1e-7):   0.00517 ns per evaluation  [per invokation:  0.693 ms] [force_approx   11.7% faster]
 fast_exp (MAE 1e-8):   0.00504 ns per evaluation  [per invokation:  0.676 ms] [force_approx   13.9% faster]

      log           :   0.00616 ns per evaluation  [per invokation:  0.827 ms]
 fast_log (   Poly2):   0.00324 ns per evaluation  [per invokation:  0.435 ms] [force_approx   47.5% faster]
 fast_log (   Poly3):   0.00338 ns per evaluation  [per invokation:  0.454 ms] [force_approx   45.2% faster]
 fast_log (   Poly4):   0.00358 ns per evaluation  [per invokation:  0.480 ms] [force_approx   42.0% faster]
 fast_log (   Poly5):   0.00393 ns per evaluation  [per invokation:  0.527 ms] [force_approx   36.3% faster]
 fast_log (   Poly6):   0.00429 ns per evaluation  [per invokation:  0.575 ms] [force_approx   30.5% faster]
 fast_log (   Poly7):   0.00463 ns per evaluation  [per invokation:  0.622 ms] [force_approx   24.9% faster]
 fast_log (   Poly8):   0.00523 ns per evaluation  [per invokation:  0.702 ms] [force_approx   15.2% faster]
 fast_log (MAE 1e-2):   0.00325 ns per evaluation  [per invokation:  0.436 ms] [force_approx   47.3% faster]
 fast_log (MAE 1e-3):   0.00368 ns per evaluation  [per invokation:  0.494 ms] [force_approx   40.2% faster]
 fast_log (MAE 1e-4):   0.00409 ns per evaluation  [per invokation:  0.548 ms] [force_approx   33.7% faster]
 fast_log (MAE 1e-5):   0.00436 ns per evaluation  [per invokation:  0.585 ms] [force_approx   29.3% faster]
 fast_log (MAE 1e-6):   0.00477 ns per evaluation  [per invokation:  0.640 ms] [force_approx   22.7% faster]
 fast_log (MAE 1e-7):   0.00511 ns per evaluation  [per invokation:  0.686 ms] [force_approx   17.1% faster]
 fast_log (MAE 1e-8):   0.00560 ns per evaluation  [per invokation:  0.751 ms] [force_approx    9.2% faster]

Passed 112 / 112 performance test.
Success!

@mcourteaux
Copy link
Contributor Author

Ready-ish for review: I still need to update API documentation comments. Mostly waiting for the build bots to run all those tests on all the hardware.

@mcourteaux
Copy link
Contributor Author

mcourteaux commented Feb 4, 2025

UPDATE: My idea is coming together really nicely. No need to answer the below question.


@abadams @zvookin I have run into a conceptual issue that made me think back about @slomp's idea about "hinting" you want the fast versions.

So, on some backends, there is a fast implementation of sin/cos already available. NVIDIA CUDA with compute capability 20 is such an example. There are sin.approx.f32 and cos.approx.f32 instructions available in PTX. These are currently not generated by Halide, so I was looking into that. A quick first test shows that they are faster than a 2-degree polynomial, and more accurate than the 7-degree polynomial. So overall, they are really worth putting them in.

The issue arises now that the selection of which sine implementation to pick, depends on the target and the schedule (i.e., the target can be CUDA, but the loop in which this call sits can still be on CPU). So, ideally, this decision is made during lowering. I'm willing to give that a shot, but I'd like to get an idea on how to actually do that properly.

My guess is that I'd have to do that using Halide::Call IR node. However, I'm not sure what to represent it as: PureExtern or PureIntrinsic. On some backends it will be a polynomial approximation provided by Halide, on other backends it might be an actual instruction. My guess is that intrinsic comes the closest to what it is: "defined by the compiler/language, and implementation dependent." So, my current idea is that I implement these calls like this:

Call::make(arg.type(), Call::IntrinsicOp::fast_sin, {arg}, Call::PureIntrinsic)

Later, there will be new early lowering pass that selects which implementation to pick. As such, if the Halide-provided polynomial is chosen, constant folding and expression simplification can still take place (which can improve some accuracy).
As a note: the code in IROperator.cpp that implements these polynomial approximations would then best move to the lowering pass that either picks the backend-native intrinsic/instruction, or the Halide-provided approximation.

@mcourteaux mcourteaux marked this pull request as draft February 4, 2025 11:31
@slomp
Copy link
Contributor

slomp commented Feb 4, 2025

Almost feels as if we'd need a percentage-style argument to indicate the "fastness" we want

@abadams
Copy link
Member

abadams commented Feb 4, 2025

Good point - I agree with your proposed solution but think the intrinsic would also need to take the accuracy arg(s), which must be compile-time constants. Maybe one arg for the metric and another arg for the threshold. For precedent for all this see the pure intrinsic lerp. It's lowered on-demand in the backends, and is lowered differently in vulkan.

You could rename Lerp.{h,cpp} to e.g. MathIntrinsics.{h,cpp} and add lower_fast_sin, lower_fast_cos, etc to it.

Another less baked-into-the-compiler option would be to add functions like target_has_feature that check the DeviceAPI the code is currently in (maybe "current_device_api_is(DeviceAPI)"?), which would allow people to write math helpers that do different things depending if they're running on GPU or CPU, and would be resolved in TargetQueryOps.cpp.

Personally I think the first option is better for fast math intrinsics, because it keeps the IR more readable. But the second option sounds like a useful feature in general.

@mcourteaux
Copy link
Contributor Author

mcourteaux commented Feb 5, 2025

Preview of my work in a different branch: main...mcourteaux:Halide:fast-math-lowering

Still a work in progress, but it's going great!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants