From 23a76046fc87b588aa41055fed670f85fb756242 Mon Sep 17 00:00:00 2001 From: Matthew Parno Date: Tue, 17 Oct 2023 16:55:30 -0400 Subject: [PATCH] Improved robustness of error handling in bracketed inverse. --- MParT/MonotoneComponent.h | 23 ++--------- MParT/Utilities/Miscellaneous.h | 3 ++ MParT/Utilities/RootFinding.h | 67 +++++++++++++++++++-------------- tests/Test_RootFinding.cpp | 51 +++++++++++++++++++++---- 4 files changed, 88 insertions(+), 56 deletions(-) diff --git a/MParT/MonotoneComponent.h b/MParT/MonotoneComponent.h index 6b6be68f..aa7744fb 100644 --- a/MParT/MonotoneComponent.h +++ b/MParT/MonotoneComponent.h @@ -395,6 +395,7 @@ class MonotoneComponent : public ConditionalMapBase auto functor = KOKKOS_CLASS_LAMBDA (typename Kokkos::TeamPolicy::member_type team_member) { unsigned int ptInd = team_member.league_rank () * team_member.team_size () + team_member.team_rank (); + int info; if(ptInd // Create a subview containing x_{1:d-1} auto pt = Kokkos::subview(xs, Kokkos::ALL(), xInd); - //std::cout << "\n\nPt0 = \n" << std::endl; + // Check for NaNs. If found, set output to nan and return for(unsigned int ii=0; ii::quiet_NaN(); return; } } - //std::cout << std::endl << std::endl; // Fill in the cache with everything that doesn't depend on x_d Kokkos::View cache(team_member.thread_scratch(1), cacheSize); expansion_.FillCache1(cache.data(), pt, DerivativeFlags::None); - // std::cout << "\n\nPt1 = \n" << std::endl; - // double max_val = 0.0; - // for(unsigned int ii=0; iimax_val) ? abs(pt(ii)) : max_val; - // } - // std::cout << std::endl << std::endl; - // if (max_val > 1e-4){ - // std::cout << "Max val = " << max_val; - // assert(max_val < 1e4); - // } // Compute the inverse Kokkos::View workspace(team_member.thread_scratch(1), workspaceSize); auto eval = SingleEvaluator(workspace.data(), cache.data(), pt, coeffs, quad_, expansion_); - //std::cout << "Here 0" << std::endl; - output(ptInd) = RootFinding::InverseSingleBracket(ys(ptInd), eval, pt(pt.extent(0)-1), xtol, ytol); - //std::cout << "Here 1" << std::endl; + output(ptInd) = RootFinding::InverseSingleBracket(ys(ptInd), eval, pt(pt.extent(0)-1), xtol, ytol, info); } }; diff --git a/MParT/Utilities/Miscellaneous.h b/MParT/Utilities/Miscellaneous.h index 6647bf26..b4e3b290 100644 --- a/MParT/Utilities/Miscellaneous.h +++ b/MParT/Utilities/Miscellaneous.h @@ -19,6 +19,9 @@ namespace mpart{ std::string const& key, std::string const& defaultValue); + /** Provides a mechanism for raising exceptions in CPU code where recovery is possible + and assertions in GPU code where exceptions aren't alllowed. + */ template struct ProcAgnosticError { static void error(const char*) { diff --git a/MParT/Utilities/RootFinding.h b/MParT/Utilities/RootFinding.h index 03e5c65d..971ef235 100644 --- a/MParT/Utilities/RootFinding.h +++ b/MParT/Utilities/RootFinding.h @@ -11,20 +11,25 @@ KOKKOS_INLINE_FUNCTION void swapPair(T& x1, T& x2, T& y1, T& y2) { simple_swap(y1, y2); } -/** Finds a bracket [xlb, xub] such that f(xlb)yd. */ +/** Finds a bracket [xlb, xub] such that f(xlb)yd. + * The info argument can be used to detect when a bracket cannot be found. Upon exit, a value of info=0 + * indicates success while a negative value indicates failure. info=-1 indicates that the function + * seems to be perfectly flat and a root might not exist. info=-2 indicates that the maximum number of + * iterations (128) was exceeded. +*/ template KOKKOS_INLINE_FUNCTION void FindBracket(FunctorType f, double& xlb, double& ylb, double& xub, double& yub, - const double yd) + const double yd, + int& info) { - const unsigned int maxIts = 40; + const unsigned int maxIts = 128; double stepSize = 1.0; - + info = 0; + ylb = f(xlb); yub = f(xub); - //std::cout << "-1" << ": " << xlb << ", " << ylb << ", " << xub << ", " << yub << " vs " << yd << std::endl; - // We actually found an upper bound... if(ylb>yd){ @@ -37,12 +42,10 @@ KOKKOS_INLINE_FUNCTION void FindBracket(FunctorType f, for(i=0; iyd."); + if(i>=maxIts) + info = -2; // We have a lower bound... }else{ @@ -65,13 +67,13 @@ KOKKOS_INLINE_FUNCTION void FindBracket(FunctorType f, for(i=0; iyd."); + if(i>=maxIts){ + info = -2; + } } } KOKKOS_INLINE_FUNCTION double Find_x_ITP(double xlb, double xub, double yd, double ylb, double yub, - double k1, double k2, double nhalf, double n0, int it, double xtol) { + double k1, double k2, double nhalf, double n0, int it, double xtol) { double xb = 0.5*(xub+xlb); // bisection point double xf = (xub*ylb - xlb*yub)/(ylb-yub); // regula-falsi point @@ -102,8 +105,13 @@ KOKKOS_INLINE_FUNCTION double Find_x_ITP(double xlb, double xub, double yd, doub return xc; } +/** Computes the inverse of a function using the ITP method. + * The info argument will be 0 upon successful completion and negative for failed inversions. + * A value of info=-2 indicates a failure to find a bracket that contains the root. In this case, a nan will be returned. + * A value of info=-1 indicates that the maximum number of iterations was exceeded. +*/ template -KOKKOS_INLINE_FUNCTION double InverseSingleBracket(double yd, FunctorType f, double x0, const double xtol, const double ftol) +KOKKOS_INLINE_FUNCTION double InverseSingleBracket(double yd, FunctorType f, double x0, const double xtol, const double ftol, int& info) { double stepSize=1.0; const unsigned int maxIts = 10000; @@ -112,16 +120,18 @@ KOKKOS_INLINE_FUNCTION double InverseSingleBracket(double yd, FunctorType f, dou double xlb, xub; double ylb, yub; double xc, yc; + info = 0; - //std::cout << "xlb = " << xlb << std::endl; xlb = xub = x0; ylb = yub = f(xlb); - // Compute bounds - FindBracket(f, xlb, ylb, xub, yub, yd); - - if ((ylb>yd)||(yub(f, xlb, ylb, xub, yub, yd, bracket_info); + if((bracket_info<0)||(((ylb>yd)||(yub::quiet_NaN(); + } // Bracketed search const double k1 = 0.1; @@ -151,8 +161,7 @@ KOKKOS_INLINE_FUNCTION double InverseSingleBracket(double yd, FunctorType f, dou if(it>maxIts) - return std::numeric_limits::quiet_NaN(); - //ProcAgnosticError::error("InverseSingleBracket: Bracket search iterations exceeds maxIts"); + info = -1; return 0.5*(xub+xlb); } diff --git a/tests/Test_RootFinding.cpp b/tests/Test_RootFinding.cpp index b630ee35..67462966 100644 --- a/tests/Test_RootFinding.cpp +++ b/tests/Test_RootFinding.cpp @@ -30,52 +30,87 @@ TEST_CASE( "RootFindingUtils", "[RootFindingUtils]") { double yd = identity(xd); double xub = 2., yub = 2.; double xlb = 0., ylb = 0.; - FindBracket(identity, xlb, ylb, xub, yub, yd); + int info = 0; + FindBracket(identity, xlb, ylb, xub, yub, yd, info); CheckFoundBounds(identity, xlb, xd, xub, ylb, yd, yub); + CHECK(info==0); } SECTION("FindBracker sigmoid") { double xd = -0.5; double yd = sigmoid(xd); double xub = 2., yub = sigmoid(xub); double xlb = 0., ylb = sigmoid(xlb); - FindBracket(sigmoid, xlb, ylb, xub, yub, yd); + int info = 0; + FindBracket(sigmoid, xlb, ylb, xub, yub, yd, info); CheckFoundBounds(sigmoid, xlb, xd, xub, ylb, yd, yub); + CHECK(info==0); + } + SECTION("FindBracket flat") { + double xd = -1.1; + double yd = 0.0; + double xub = 2., yub = -1.0; + double xlb = 0., ylb = -1.0; + int info = 0; + auto f = [](double x){return -1.0;}; + FindBracket(f, xlb, ylb, xub, yub, yd, info); + CHECK(info==-1); } SECTION("Test Inverse Linear, low x0") { double xd = 0.5, yd = identity(xd); double x0 = 0.0, xtol = 1e-5, ftol = 1e-5; - double xd_found = InverseSingleBracket(yd, identity, x0, xtol, ftol); + int info = 0; + double xd_found = InverseSingleBracket(yd, identity, x0, xtol, ftol, info); CHECK( xd_found == Approx(xd).epsilon(2*xtol)); + CHECK(info==0); } SECTION("Test Inverse Linear, high x0") { double xd = 0.5, yd = identity(xd); double x0 = 1.0, xtol = 1e-5, ftol = 1e-5; - double xd_found = InverseSingleBracket(yd, identity, x0, xtol, ftol); + int info = 0; + double xd_found = InverseSingleBracket(yd, identity, x0, xtol, ftol, info); CHECK( xd_found == Approx(xd).epsilon(2*xtol)); + CHECK(info==0); + } + SECTION("Test Inverse Flat") { + double xd = 0.5, yd = -1.0; + double x0 = 0.0, xtol = 1e-5, ftol = 1e-5; + auto f = [](double x){return -1.0;}; + int info = 0; + double xd_found = InverseSingleBracket(yd, f, x0, xtol, ftol, info); + CHECK( std::isnan(xd_found)); + CHECK(info==-2); } SECTION("Test Inverse Sigmoid, low x0") { double xd = 0.5, yd = sigmoid(xd); double x0 = 0.0, xtol = 1e-5, ftol = 1e-5; - double xd_found = InverseSingleBracket(yd, sigmoid, x0, xtol, ftol); + int info = 0; + double xd_found = InverseSingleBracket(yd, sigmoid, x0, xtol, ftol, info); CHECK( xd_found == Approx(xd).epsilon(2*xtol)); + CHECK(info==0); } SECTION("Test Inverse Sigmoid, high x0") { double xd = 0.5, yd = sigmoid(xd); double x0 = 1.0, xtol = 1e-5, ftol = 1e-5; - double xd_found = InverseSingleBracket(yd, sigmoid, x0, xtol, ftol); + int info; + double xd_found = InverseSingleBracket(yd, sigmoid, x0, xtol, ftol, info); CHECK( xd_found == Approx(xd).epsilon(2*xtol)); + CHECK(info==0); } SECTION("Test Inverse Sigmoid Combo, low x0") { double xd = 0.5, yd = sigmoid_combo(xd); double x0 = -5.0, xtol = 1e-5, ftol = 1e-5; - double xd_found = InverseSingleBracket(yd, sigmoid_combo, x0, xtol, ftol); + int info; + double xd_found = InverseSingleBracket(yd, sigmoid_combo, x0, xtol, ftol, info); CHECK( xd_found == Approx(xd).epsilon(2*xtol)); + CHECK(info==0); } SECTION("Test Inverse Sigmoid Combo, high x0") { double xd = 0.5, yd = sigmoid_combo(xd); double x0 = 5.0, xtol = 1e-5, ftol = 1e-5; - double xd_found = InverseSingleBracket(yd, sigmoid_combo, x0, xtol, ftol); + int info; + double xd_found = InverseSingleBracket(yd, sigmoid_combo, x0, xtol, ftol, info); CHECK( xd_found == Approx(xd).epsilon(2*xtol)); + CHECK(info==0); } } \ No newline at end of file